From 19123ca957d95fe19133508a1de53e3b3c86d9c1 Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Tue, 12 Aug 2025 06:20:39 -0400 Subject: [PATCH 01/45] refactor: standardize InferenceRouter model handling (#2965) --- llama_stack/apis/common/errors.py | 10 ++++ llama_stack/core/routers/inference.py | 49 ++++++------------- llama_stack/core/routing_tables/vector_dbs.py | 4 +- .../remote/inference/ollama/ollama.py | 3 -- 4 files changed, 28 insertions(+), 38 deletions(-) diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py index 95d6ac18e..6e0fa0b3c 100644 --- a/llama_stack/apis/common/errors.py +++ b/llama_stack/apis/common/errors.py @@ -62,3 +62,13 @@ class SessionNotFoundError(ValueError): def __init__(self, session_name: str) -> None: message = f"Session '{session_name}' not found or access denied." super().__init__(message) + + +class ModelTypeError(TypeError): + """raised when a model is present but not the correct type""" + + def __init__(self, model_name: str, model_type: str, expected_model_type: str) -> None: + message = ( + f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'" + ) + super().__init__(message) diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 79ab7c34f..52581cc9d 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -18,7 +18,7 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, ) -from llama_stack.apis.common.errors import ModelNotFoundError +from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.inference import ( BatchChatCompletionResponse, BatchCompletionResponse, @@ -177,6 +177,15 @@ class InferenceRouter(Inference): encoded = self.formatter.encode_content(messages) return len(encoded.tokens) if encoded and encoded.tokens else 0 + async def _get_model(self, model_id: str, expected_model_type: str) -> Model: + """takes a model id and gets model after ensuring that it is accessible and of the correct type""" + model = await self.routing_table.get_model(model_id) + if model is None: + raise ModelNotFoundError(model_id) + if model.model_type != expected_model_type: + raise ModelTypeError(model_id, model.model_type, expected_model_type) + return model + async def chat_completion( self, model_id: str, @@ -195,11 +204,7 @@ class InferenceRouter(Inference): ) if sampling_params is None: sampling_params = SamplingParams() - model = await self.routing_table.get_model(model_id) - if model is None: - raise ModelNotFoundError(model_id) - if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + model = await self._get_model(model_id, ModelType.llm) if tool_config: if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") @@ -301,11 +306,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}", ) - model = await self.routing_table.get_model(model_id) - if model is None: - raise ModelNotFoundError(model_id) - if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + model = await self._get_model(model_id, 
ModelType.llm) provider = await self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -355,11 +356,7 @@ class InferenceRouter(Inference): task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: logger.debug(f"InferenceRouter.embeddings: {model_id}") - model = await self.routing_table.get_model(model_id) - if model is None: - raise ModelNotFoundError(model_id) - if model.model_type == ModelType.llm: - raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") + await self._get_model(model_id, ModelType.embedding) provider = await self.routing_table.get_provider_impl(model_id) return await provider.embeddings( model_id=model_id, @@ -395,12 +392,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", ) - model_obj = await self.routing_table.get_model(model) - if model_obj is None: - raise ModelNotFoundError(model) - if model_obj.model_type == ModelType.embedding: - raise ValueError(f"Model '{model}' is an embedding model and does not support completions") - + model_obj = await self._get_model(model, ModelType.llm) params = dict( model=model_obj.identifier, prompt=prompt, @@ -476,11 +468,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}", ) - model_obj = await self.routing_table.get_model(model) - if model_obj is None: - raise ModelNotFoundError(model) - if model_obj.model_type == ModelType.embedding: - raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions") + model_obj = await self._get_model(model, ModelType.llm) # Use the OpenAI client for a bit of extra input validation without # exposing the OpenAI client itself as part of our API surface @@ -567,12 +555,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}", ) - model_obj = await self.routing_table.get_model(model) - if model_obj is None: - raise ModelNotFoundError(model) - if model_obj.model_type != ModelType.embedding: - raise ValueError(f"Model '{model}' is not an embedding model") - + model_obj = await self._get_model(model, ModelType.embedding) params = dict( model=model_obj.identifier, input=input, diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index c81a27a3b..e8dc46997 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -8,7 +8,7 @@ from typing import Any from pydantic import TypeAdapter -from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError +from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError from llama_stack.apis.models import ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs @@ -66,7 +66,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): if model is None: raise ModelNotFoundError(embedding_model) if model.model_type != ModelType.embedding: - raise ValueError(f"Model {embedding_model} is not an embedding model") + raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding) if "embedding_dimension" not in model.metadata: raise ValueError(f"Model {embedding_model} does not have an embedding dimension") vector_db_data = { diff --git 
a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 26b4dec76..a93421536 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -457,9 +457,6 @@ class OllamaInferenceAdapter( user: str | None = None, ) -> OpenAIEmbeddingsResponse: model_obj = await self._get_model(model) - if model_obj.model_type != ModelType.embedding: - raise ValueError(f"Model {model} is not an embedding model") - if model_obj.provider_resource_id is None: raise ValueError(f"Model {model} has no provider_resource_id set") From 4a13ef45e984af8274b06e118af3baadf770a3bc Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Tue, 12 Aug 2025 17:32:52 +0200 Subject: [PATCH 02/45] fix: Implement missing `run_moderation` method in `PromptGuardSafetyImpl` (#3101) # What does this PR do? This PR addresses an issue where `PromptGuardSafetyImpl` was an incomplete implementation of an abstract class. The class was missing the required run_moderation method from its parent interface. Currently, running `pre-commit` locally fails with the error below. ``` llama_stack/providers/inline/safety/prompt_guard/__init__.py:15: error: Cannot instantiate abstract class "PromptGuardSafetyImpl" with abstract attribute "run_moderation" [abstract] Found 1 error in 1 file (checked 410 source files) ``` This PR fixes the issue as follows - Added the missing run_moderation method to PromptGuardSafetyImpl - Method raises NotImplementedError with appropriate message indicating this functionality is not implemented for PromptGuard - This allows the class to be properly instantiated while clearly indicating the limitation Signed-off-by: Mustafa Elbehery --- .../providers/inline/safety/prompt_guard/prompt_guard.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index e11ec5cf5..801500dee 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -64,6 +64,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): return await self.shield.run(messages) + async def run_moderation(self, input: str | list[str], model: str): + raise NotImplementedError("run_moderation not implemented for PromptGuard") + class PromptGuardShield: def __init__( From b70e2f1f09bae55603674e56de46a62608ee588e Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 12 Aug 2025 10:40:32 -0500 Subject: [PATCH 03/45] fix(dep): update to openai >= 1.99.6 and use new Function location (#3087) # What does this PR do? 
closes #3072 ## Test Plan ci --- llama_stack/providers/utils/inference/openai_compat.py | 2 +- pyproject.toml | 2 +- uv.lock | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index e6e5ccc8a..9a77c8cc4 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -70,7 +70,7 @@ from openai.types.chat.chat_completion_chunk import ( from openai.types.chat.chat_completion_content_part_image_param import ( ImageURL as OpenAIImageURL, ) -from openai.types.chat.chat_completion_message_tool_call_param import ( +from openai.types.chat.chat_completion_message_tool_call import ( Function as OpenAIFunction, ) from pydantic import BaseModel diff --git a/pyproject.toml b/pyproject.toml index a77ec5ac9..1b0850631 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ "jsonschema", "llama-stack-client>=0.2.17", "llama-api-client>=0.1.2", - "openai>=1.66", + "openai>=1.99.6", "prompt-toolkit", "python-dotenv", "python-jose[cryptography]", diff --git a/uv.lock b/uv.lock index c10a7962c..f57b5a161 100644 --- a/uv.lock +++ b/uv.lock @@ -1674,7 +1674,7 @@ requires-dist = [ { name = "llama-api-client", specifier = ">=0.1.2" }, { name = "llama-stack-client", specifier = ">=0.2.17" }, { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.17" }, - { name = "openai", specifier = ">=1.66" }, + { name = "openai", specifier = ">=1.99.6" }, { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" }, { name = "opentelemetry-sdk", specifier = ">=1.30.0" }, { name = "pandas", marker = "extra == 'ui'" }, @@ -2301,7 +2301,7 @@ wheels = [ [[package]] name = "openai" -version = "1.98.0" +version = "1.99.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -2313,9 +2313,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d8/9d/52eadb15c92802711d6b6cf00df3a6d0d18b588f4c5ba5ff210c6419fc03/openai-1.98.0.tar.gz", hash = "sha256:3ee0fcc50ae95267fd22bd1ad095ba5402098f3df2162592e68109999f685427", size = 496695, upload-time = "2025-07-30T12:48:03.701Z" } +sdist = { url = "https://files.pythonhosted.org/packages/11/45/38a87bd6949236db5ae3132f41d5861824702b149f86d2627d6900919103/openai-1.99.6.tar.gz", hash = "sha256:f48f4239b938ef187062f3d5199a05b69711d8b600b9a9b6a3853cd271799183", size = 505364, upload-time = "2025-08-09T15:20:54.438Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/fe/f64631075b3d63a613c0d8ab761d5941631a470f6fa87eaaee1aa2b4ec0c/openai-1.98.0-py3-none-any.whl", hash = "sha256:b99b794ef92196829120e2df37647722104772d2a74d08305df9ced5f26eae34", size = 767713, upload-time = "2025-07-30T12:48:01.264Z" }, + { url = "https://files.pythonhosted.org/packages/d6/dd/9aa956485c2856346b3181542fbb0aea4e5b457fa7a523944726746da8da/openai-1.99.6-py3-none-any.whl", hash = "sha256:e40d44b2989588c45ce13819598788b77b8fb80ba2f7ae95ce90d14e46f1bd26", size = 786296, upload-time = "2025-08-09T15:20:51.95Z" }, ] [[package]] From 393f3714b0ea9b71c425ff932510dea4709ea1f7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Aug 2025 08:44:24 -0700 Subject: [PATCH 04/45] chore(python-deps): bump torch from 2.7.1 to 2.8.0 (#3082) Bumps [torch](https://github.com/pytorch/pytorch) 
from 2.7.1 to 2.8.0.

Release notes (sourced from torch's releases):

PyTorch 2.8.0 Release Notes: Highlights
... (truncated)

Commits:
- ba56102 Cherrypick: Add the RunLLM widget to the website (#159592)
- c525a02 [dynamo, docs] cherry pick torch.compile programming model docs into 2.8 (#15...
- a1cb3cc [Release Only] Remove nvshmem from list of preload libraries (#158925)
- c76b235 Move out super large one off foreach_copy test (#158880)
- 20a0e22 Revert "[Dynamo] Allow inlining into AO quantization modules (#152934)" (#158...
- 9167ac8 [MPS] Switch Cholesky decomp to column wise (#158237)
- 5534685 [MPS] Reimplement tri[ul] as Metal shaders (#158867)
- d19e08d Cherry pick PR 158746 (#158801)
- a6c044a [cherry-pick] Unify torch.tensor and torch.ops.aten.scalar_tensor behavior (#...
- 620ebd0 [Dynamo] Use proper sources for constructing dataclass defaults (#158689)
- Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=torch&package-manager=uv&previous-version=2.7.1&new-version=2.8.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`.
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- uv.lock | 75 ++++++++++++++++++++++++++++++--------------------------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/uv.lock b/uv.lock index f57b5a161..caafa1197 100644 --- a/uv.lock +++ b/uv.lock @@ -1632,10 +1632,10 @@ test = [ { name = "pypdf" }, { name = "requests" }, { name = "sqlalchemy", extra = ["asyncio"] }, - { name = "torch", version = "2.7.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" }, - { name = "torch", version = "2.7.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform != 'darwin'" }, - { name = "torchvision", version = "0.22.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "torchvision", version = "0.22.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "torch", version = "2.8.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.8.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform != 'darwin'" }, + { name = "torchvision", version = "0.23.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "torchvision", version = "0.23.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, { name = "transformers" }, { name = "weaviate-client" }, ] @@ -4310,7 +4310,7 @@ wheels = [ [[package]] name = "torch" -version = "2.7.1" +version = "2.8.0" source = { registry = "https://download.pytorch.org/whl/cpu" } resolution-markers = [ "python_full_version >= '3.13' and sys_platform == 'darwin'", @@ -4326,14 +4326,14 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform == 'darwin'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:7b4f8b2b83bd08f7d399025a9a7b323bdbb53d20566f1e0d584689bb92d82f9a" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:95af97e7b2cecdc89edc0558962a51921bf9c61538597dbec6b7cc48d31e2e13" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7ecd868a086468e1bcf74b91db425c1c2951a9cfcd0592c4c73377b7e42485ae" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:a47b7986bee3f61ad217d8a8ce24605809ab425baf349f97de758815edd2ef54" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:fbe2e149c5174ef90d29a5f84a554dfaf28e003cb4f61fa2c8c024c17ec7ca58" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:057efd30a6778d2ee5e2374cd63a63f63311aa6f33321e627c655df60abdd390" }, ] [[package]] name = "torch" -version = "2.7.1+cpu" +version = "2.8.0+cpu" source = { registry = 
"https://download.pytorch.org/whl/cpu" } resolution-markers = [ "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", @@ -4351,21 +4351,24 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform != 'darwin'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3bf2db5adf77b433844f080887ade049c4705ddf9fe1a32023ff84ff735aa5ad" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:8f8b3cfc53010a4b4a3c7ecb88c212e9decc4f5eeb6af75c3c803937d2d60947" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:0bc887068772233f532b51a3e8c8cfc682ae62bef74bf4e0c53526c8b9e4138f" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp312-cp312-win_arm64.whl", hash = "sha256:a2618775f32eb4126c5b2050686da52001a08cffa331637d9cf51c8250931e00" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:eb17646792ac4374ffc87e42369f45d21eff17c790868963b90483ef0b6db4ef" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:84ea1f6a1d15663037d01b121d6e33bb9da3c90af8e069e5072c30f413455a57" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:b66f77f6f67317344ee083aa7ac4751a14395fcb38060d564bf513978d267153" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:56136a2aca6707df3c8811e46ea2d379eaafd18e656e2fd51e8e4d0ca995651b" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:355614185a2aea7155f9c88a20bfd49de5f3063866f3cf9b2f21b6e9e59e31e0" }, - { url = "https://download.pytorch.org/whl/cpu/torch-2.7.1%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:464bca1bc9452f2ccd676514688896e66b9488f2a0268ecd3ac497cf09c5aac1" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp312-cp312-linux_s390x.whl", hash = "sha256:0e34e276722ab7dd0dffa9e12fe2135a9b34a0e300c456ed7ad6430229404eb5" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:610f600c102386e581327d5efc18c0d6edecb9820b4140d26163354a99cd800d" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:cb9a8ba8137ab24e36bf1742cb79a1294bd374db570f09fc15a5e1318160db4e" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:2be20b2c05a0cce10430cc25f32b689259640d273232b2de357c35729132256d" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp312-cp312-win_arm64.whl", hash = "sha256:99fc421a5d234580e45957a7b02effbf3e1c884a5dd077afc85352c77bf41434" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313-linux_s390x.whl", hash = "sha256:8b5882276633cf91fe3d2d7246c743b94d44a7e660b27f1308007fdb1bb89f7d" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a5064b5e23772c8d164068cc7c12e01a75faf7b948ecd95a0d4007d7487e5f25" }, + { url = 
"https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:8f81dedb4c6076ec325acc3b47525f9c550e5284a18eae1d9061c543f7b6e7de" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:e1ee1b2346ade3ea90306dfbec7e8ff17bc220d344109d189ae09078333b0856" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313-win_arm64.whl", hash = "sha256:64c187345509f2b1bb334feed4666e2c781ca381874bde589182f81247e61f88" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:af81283ac671f434b1b25c95ba295f270e72db1fad48831eb5e4748ff9840041" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:a9dbb6f64f63258bc811e2c0c99640a81e5af93c531ad96e95c5ec777ea46dab" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:6d93a7165419bc4b2b907e859ccab0dea5deeab261448ae9a5ec5431f14c0e64" }, ] [[package]] name = "torchvision" -version = "0.22.1" +version = "0.23.0" source = { registry = "https://download.pytorch.org/whl/cpu" } resolution-markers = [ "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", @@ -4376,21 +4379,21 @@ resolution-markers = [ dependencies = [ { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "pillow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "torch", version = "2.7.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" }, - { name = "torch", version = "2.7.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "torch", version = "2.8.0", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.8.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:153f1790e505bd6da123e21eee6e83e2e155df05c0fe7d56347303067d8543c5" }, - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:964414eef19459d55a10e886e2fca50677550e243586d1678f65e3f6f6bac47a" }, - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9c3ae3319624c43cc8127020f46c14aa878406781f0899bb6283ae474afeafbf" }, - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:4a614a6a408d2ed74208d0ea6c28a2fbb68290e9a7df206c5fef3f0b6865d307" }, - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:043d9e35ed69c2e586aff6eb9e2887382e7863707115668ac9d140da58f42cba" }, - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:27142bcc8a984227a6dcf560985e83f52b82a7d3f5fe9051af586a2ccc46ef26" }, + { url = 
"https://download.pytorch.org/whl/cpu/torchvision-0.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e0e2c04a91403e8dd3af9756c6a024a1d9c0ed9c0d592a8314ded8f4fe30d440" }, + { url = "https://download.pytorch.org/whl/cpu/torchvision-0.23.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6dd7c4d329a0e03157803031bc856220c6155ef08c26d4f5bbac938acecf0948" }, + { url = "https://download.pytorch.org/whl/cpu/torchvision-0.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1c37e325e09a184b730c3ef51424f383ec5745378dc0eca244520aca29722600" }, + { url = "https://download.pytorch.org/whl/cpu/torchvision-0.23.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2f7fd6c15f3697e80627b77934f77705f3bc0e98278b989b2655de01f6903e1d" }, + { url = "https://download.pytorch.org/whl/cpu/torchvision-0.23.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:2df618e1143805a7673aaf82cb5720dd9112d4e771983156aaf2ffff692eebf9" }, + { url = "https://download.pytorch.org/whl/cpu/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a3299d2b1d5a7aed2d3b6ffb69c672ca8830671967eb1cee1497bacd82fe47b" }, ] [[package]] name = "torchvision" -version = "0.22.1+cpu" +version = "0.23.0+cpu" source = { registry = "https://download.pytorch.org/whl/cpu" } resolution-markers = [ "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", @@ -4399,15 +4402,15 @@ resolution-markers = [ dependencies = [ { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, { name = "pillow", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "torch", version = "2.7.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "torch", version = "2.8.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b5fa7044bd82c6358e8229351c98070cf3a7bf4a6e89ea46352ae6c65745ef94" }, - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:433cb4dbced7291f17064cea08ac1e5aebd02ec190e1c207d117ad62a8961f2b" }, - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:a93c21f18c33a819616b3dda7655aa4de40b219682c654175b6bbeb65ecc2e5f" }, - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:34c914ad4728b81848ac802c5fc5eeb8de8ff4058cc59c1463a74ce4f4fbf0d8" }, - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:ab7ae82529887c704c1b5d1d5198f65dc777d04fc3858b374503a6deedb82b19" }, - { url = "https://download.pytorch.org/whl/cpu/torchvision-0.22.1%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:b2d1c4bdbfd8e6c779dc810a6171b56224f1332fc46986810d4081bed1633804" }, + { url = 
"https://download.pytorch.org/whl/cpu/torchvision-0.23.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ae459d4509d3b837b978dc6c66106601f916b6d2cda75c137e3f5f48324ce1da" }, + { url = "https://download.pytorch.org/whl/cpu/torchvision-0.23.0%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:a651ccc540cf4c87eb988730c59c2220c52b57adc276f044e7efb9830fa65a1d" }, + { url = "https://download.pytorch.org/whl/cpu/torchvision-0.23.0%2Bcpu-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:dea90a67d60a5366b0358a0b8d6bf267805278697d6fd950cf0e31139e56d1be" }, + { url = "https://download.pytorch.org/whl/cpu/torchvision-0.23.0%2Bcpu-cp313-cp313-win_amd64.whl", hash = "sha256:82928788025170c62e7df1120dcdc0cd175bfc31c08374613ce6d1a040bc0cda" }, + { url = "https://download.pytorch.org/whl/cpu/torchvision-0.23.0%2Bcpu-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:474d77adbbbed5166db3e5636b4b4ae3399c66ef5bfa12536e254b32259c90c0" }, + { url = "https://download.pytorch.org/whl/cpu/torchvision-0.23.0%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:8d6a47e23d7896f0ef9aa7ea7179eb6324e82438aa66d19884c2020d0646b104" }, ] [[package]] From 88c4fdc5d7fecd0468815f9eda25d72f722745a6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Aug 2025 08:44:39 -0700 Subject: [PATCH 05/45] chore(python-deps): bump chromadb from 1.0.15 to 1.0.16 (#3083) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [chromadb](https://github.com/chroma-core/chroma) from 1.0.15 to 1.0.16.
Release notes (sourced from chromadb's releases):

1.0.16
Version: 1.0.16, Git ref: refs/tags/1.0.16, Build Date: 2025-08-08T00:26, PIP Package: chroma-1.0.16.tar.gz, container images tagged 1.0.16 on GitHub Container Registry and DockerHub

What's Changed
... (truncated)

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=chromadb&package-manager=uv&previous-version=1.0.15&new-version=1.0.16)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`.
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- uv.lock | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/uv.lock b/uv.lock index caafa1197..9f4ba4adb 100644 --- a/uv.lock +++ b/uv.lock @@ -476,7 +476,7 @@ wheels = [ [[package]] name = "chromadb" -version = "1.0.15" +version = "1.0.16" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "bcrypt" }, @@ -507,13 +507,13 @@ dependencies = [ { name = "typing-extensions" }, { name = "uvicorn", extra = ["standard"] }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ad/e2/0653b2e539db5512d2200c759f1bc7f9ef5609fe47f3c7d24b82f62dc00f/chromadb-1.0.15.tar.gz", hash = "sha256:3e910da3f5414e2204f89c7beca1650847f2bf3bd71f11a2e40aad1eb31050aa", size = 1218840, upload-time = "2025-07-02T17:07:09.875Z" } +sdist = { url = "https://files.pythonhosted.org/packages/15/2a/5b7e793d2a27c425e9f1813e9cb965b70e9bda08b69ee15a10e07dc3e59a/chromadb-1.0.16.tar.gz", hash = "sha256:3c864b5beb5e131bdc1f83c0b63a01ec481c6ee52028f088563ffba8478478e1", size = 1241545, upload-time = "2025-08-08T00:25:41.414Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/85/5a/866c6f0c2160cbc8dca0cf77b2fb391dcf435b32a58743da1bc1a08dc442/chromadb-1.0.15-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:51791553014297798b53df4e043e9c30f4e8bd157647971a6bb02b04bfa65f82", size = 18838820, upload-time = "2025-07-02T17:07:07.632Z" }, - { url = "https://files.pythonhosted.org/packages/e1/18/ff9b58ab5d334f5ecff7fdbacd6761bac467176708fa4d2500ae7c048af0/chromadb-1.0.15-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:48015803c0631c3a817befc276436dc084bb628c37fd4214047212afb2056291", size = 18057131, upload-time = "2025-07-02T17:07:05.15Z" }, - { url = "https://files.pythonhosted.org/packages/31/49/74e34cc5aeeb25aff2c0ede6790b3671e14c1b91574dd8f98d266a4c5aad/chromadb-1.0.15-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b73cd6fb32fcdd91c577cca16ea6112b691d72b441bb3f2140426d1e79e453a", size = 18595284, upload-time = "2025-07-02T17:06:59.102Z" }, - { url = "https://files.pythonhosted.org/packages/cb/33/190df917a057067e37f8b48d082d769bed8b3c0c507edefc7b6c6bb577d0/chromadb-1.0.15-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:479f1b401af9e7c20f50642ffb3376abbfd78e2b5b170429f7c79eff52e367db", size = 19526626, upload-time = "2025-07-02T17:07:02.163Z" }, - { url = "https://files.pythonhosted.org/packages/a1/30/6890da607358993f87a01e80bcce916b4d91515ce865f07dc06845cb472f/chromadb-1.0.15-cp39-abi3-win_amd64.whl", hash = "sha256:e0cb3b93fdc42b1786f151d413ef36299f30f783a30ce08bf0bfb12e552b4190", size = 19520490, upload-time = "2025-07-02T17:07:11.559Z" }, + { url = "https://files.pythonhosted.org/packages/a3/9d/bffcc814272c9b7982551803b2d45b77f39eeea1b9e965c00c05ee81c649/chromadb-1.0.16-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:144163ce7ca4f4448684d5d0c13ebb37c4d68490ecb60967a95d05cea30e0d2d", size = 18942157, upload-time = "2025-08-08T00:25:38.459Z" }, + { url = "https://files.pythonhosted.org/packages/58/4e/de0086f3cbcfd667d75d112bb546386803ab5335599bf7099272a675e98b/chromadb-1.0.16-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:4ebcc5894e6fbb6b576452bbf4659746bfe58d9daf99a18363364e9497434bd2", size = 18147831, upload-time = "2025-08-08T00:25:35.546Z" }, + { url = 
"https://files.pythonhosted.org/packages/0e/7f/a8aff4ce96281bcb9731d10b2554f41963dd0b47acb4f90a78b2b7c4f199/chromadb-1.0.16-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:937051fc3aae94f7c171503d8f1f7662820aacc75acf45f28d3656c75c5ff1f8", size = 18682195, upload-time = "2025-08-08T00:25:29.654Z" }, + { url = "https://files.pythonhosted.org/packages/a3/9c/2a97d0257176aae472dff6f1ef1b7050449f384e420120e0f31d2d8f532f/chromadb-1.0.16-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0f5c5ad0c59154a9cab1506b857bab8487b588352e668cf1222c54bb9d52daa", size = 19635695, upload-time = "2025-08-08T00:25:32.68Z" }, + { url = "https://files.pythonhosted.org/packages/96/8a/f7e810f3cbdc9186ba4e649dc32711b7ab2c23aba37cf61175f731d22293/chromadb-1.0.16-cp39-abi3-win_amd64.whl", hash = "sha256:2528c01bd8b3facca9d0e1ffac866767c386b94604df484fc792ee891c86e09a", size = 19641144, upload-time = "2025-08-08T00:25:43.446Z" }, ] [[package]] From 6812aa1e1e6aad77706f1e063fd4ed6603cf9871 Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Tue, 12 Aug 2025 11:52:57 -0400 Subject: [PATCH 06/45] chore: bump min python version in docs and tests (#3103) # What does this PR do? the minimum python version for the project was bumped to 3.12 a couple months ago, but there remains some artifacts in the repo suggesting we support >=3.10 Signed-off-by: Nathan Weinberg --- docs/source/apis/external.md | 4 ++-- docs/source/providers/external/external-providers-guide.md | 2 +- tests/external/llama-stack-api-weather/pyproject.toml | 2 +- tests/external/llama-stack-provider-kaze/pyproject.toml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/apis/external.md b/docs/source/apis/external.md index cc13deb9b..5831990b0 100644 --- a/docs/source/apis/external.md +++ b/docs/source/apis/external.md @@ -111,7 +111,7 @@ name = "llama-stack-api-weather" version = "0.1.0" description = "Weather API for Llama Stack" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.12" dependencies = ["llama-stack", "pydantic"] [build-system] @@ -231,7 +231,7 @@ name = "llama-stack-provider-kaze" version = "0.1.0" description = "Kaze weather provider for Llama Stack" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.12" dependencies = ["llama-stack", "pydantic", "aiohttp"] [build-system] diff --git a/docs/source/providers/external/external-providers-guide.md b/docs/source/providers/external/external-providers-guide.md index 2479d406f..e2d4ebea9 100644 --- a/docs/source/providers/external/external-providers-guide.md +++ b/docs/source/providers/external/external-providers-guide.md @@ -226,7 +226,7 @@ uv init name = "llama-stack-provider-ollama" version = "0.1.0" description = "Ollama provider for Llama Stack" -requires-python = ">=3.10" +requires-python = ">=3.12" dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"] ``` diff --git a/tests/external/llama-stack-api-weather/pyproject.toml b/tests/external/llama-stack-api-weather/pyproject.toml index 566e1e9aa..ac2d8d632 100644 --- a/tests/external/llama-stack-api-weather/pyproject.toml +++ b/tests/external/llama-stack-api-weather/pyproject.toml @@ -3,7 +3,7 @@ name = "llama-stack-api-weather" version = "0.1.0" description = "Weather API for Llama Stack" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.12" dependencies = ["llama-stack", "pydantic"] [build-system] diff --git 
a/tests/external/llama-stack-provider-kaze/pyproject.toml b/tests/external/llama-stack-provider-kaze/pyproject.toml index 7bbf1f843..e2438a18a 100644 --- a/tests/external/llama-stack-provider-kaze/pyproject.toml +++ b/tests/external/llama-stack-provider-kaze/pyproject.toml @@ -3,7 +3,7 @@ name = "llama-stack-provider-kaze" version = "0.1.0" description = "Kaze weather provider for Llama Stack" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.12" dependencies = ["llama-stack", "pydantic", "aiohttp"] [build-system] From 4fec49dfdb4ecb9a11411ae8a672f63b4c7cca58 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 12 Aug 2025 10:24:01 -0700 Subject: [PATCH 07/45] feat(responses): add include parameter (#3115) Well our Responses tests use it so we better include it in the API, no? I discovered it because I want to make sure `llama-stack-client` can be used always instead of `openai-python` as the client (we do want to be _truly_ compatible.) --- docs/_static/llama-stack-spec.html | 7 +++++++ docs/_static/llama-stack-spec.yaml | 6 ++++++ llama_stack/apis/agents/agents.py | 2 ++ .../inline/agents/meta_reference/agents.py | 13 ++++++++++++- .../agents/meta_reference/openai_responses.py | 1 + 5 files changed, 28 insertions(+), 1 deletion(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index d480ff592..a16d3fce5 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -8515,6 +8515,13 @@ "$ref": "#/components/schemas/OpenAIResponseInputTool" } }, + "include": { + "type": "array", + "items": { + "type": "string" + }, + "description": "(Optional) Additional fields to include in the response." + }, "max_infer_iters": { "type": "integer" } diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 9c0fba554..d5ad66d5e 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6188,6 +6188,12 @@ components: type: array items: $ref: '#/components/schemas/OpenAIResponseInputTool' + include: + type: array + items: + type: string + description: >- + (Optional) Additional fields to include in the response. max_infer_iters: type: integer additionalProperties: false diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index e816da766..7dd3e9289 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -706,6 +706,7 @@ class Agents(Protocol): temperature: float | None = None, text: OpenAIResponseText | None = None, tools: list[OpenAIResponseInputTool] | None = None, + include: list[str] | None = None, max_infer_iters: int | None = 10, # this is an extension to the OpenAI API ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: """Create a new OpenAI response. @@ -713,6 +714,7 @@ class Agents(Protocol): :param input: Input message(s) to create the response. :param model: The underlying LLM used for completions. :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses. + :param include: (Optional) Additional fields to include in the response. :returns: An OpenAIResponseObject. """ ... 
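A minimal usage sketch for the new `include` parameter, assuming an OpenAI-compatible client pointed at a running Llama Stack server; the base URL, model name, and vector store id below are placeholders, and `file_search_call.results` is one of the standard OpenAI Responses include values.

```python
from openai import OpenAI

# Point an OpenAI-compatible client at a running Llama Stack server.
# The base URL, API key, model, and vector store id are placeholders.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="openai/gpt-4o",
    input="Find the section on refunds in the uploaded policy document.",
    tools=[{"type": "file_search", "vector_store_ids": ["vs_123"]}],
    # The new parameter: ask the server to attach extra fields to the response,
    # here the per-result payloads of the file search call.
    include=["file_search_call.results"],
)
print(response.output)
```
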
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 15695ec48..0f12a0865 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -327,10 +327,21 @@ class MetaReferenceAgentsImpl(Agents): temperature: float | None = None, text: OpenAIResponseText | None = None, tools: list[OpenAIResponseInputTool] | None = None, + include: list[str] | None = None, max_infer_iters: int | None = 10, ) -> OpenAIResponseObject: return await self.openai_responses_impl.create_openai_response( - input, model, instructions, previous_response_id, store, stream, temperature, text, tools, max_infer_iters + input, + model, + instructions, + previous_response_id, + store, + stream, + temperature, + text, + tools, + include, + max_infer_iters, ) async def list_openai_responses( diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 7eb2b3897..db70bc046 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -333,6 +333,7 @@ class OpenAIResponsesImpl: temperature: float | None = None, text: OpenAIResponseText | None = None, tools: list[OpenAIResponseInputTool] | None = None, + include: list[str] | None = None, max_infer_iters: int | None = 10, ): stream = bool(stream) From 1721aafc1fc13297d7db6c3bcb0c65344c6a7cd0 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 12 Aug 2025 10:39:09 -0700 Subject: [PATCH 08/45] feat(responses): type file results properly (#3117) Another thing our tests implicitly depended on. 
--- docs/_static/llama-stack-spec.html | 74 +++++++++++++------ docs/_static/llama-stack-spec.yaml | 46 ++++++++++-- llama_stack/apis/agents/openai_responses.py | 19 ++++- .../agents/meta_reference/openai_responses.py | 14 ++-- 4 files changed, 117 insertions(+), 36 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index a16d3fce5..e2c53d4b0 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -8293,28 +8293,60 @@ "type": "array", "items": { "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" + "properties": { + "attributes": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } + "description": "(Optional) Key-value attributes associated with the file" + }, + "file_id": { + "type": "string", + "description": "Unique identifier of the file containing the result" + }, + "filename": { + "type": "string", + "description": "Name of the file containing the result" + }, + "score": { + "type": "number", + "description": "Relevance score for this search result (between 0 and 1)" + }, + "text": { + "type": "string", + "description": "Text content of the search result" + } + }, + "additionalProperties": false, + "required": [ + "attributes", + "file_id", + "filename", + "score", + "text" + ], + "title": "OpenAIResponseOutputMessageFileSearchToolCallResults", + "description": "Search results returned by the file search operation." }, "description": "(Optional) Search results returned by the file search operation" } diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index d5ad66d5e..85cec3a78 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6021,14 +6021,44 @@ components: type: array items: type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + properties: + attributes: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + description: >- + (Optional) Key-value attributes associated with the file + file_id: + type: string + description: >- + Unique identifier of the file containing the result + filename: + type: string + description: Name of the file containing the result + score: + type: number + description: >- + Relevance score for this search result (between 0 and 1) + text: + type: string + description: Text content of the search result + additionalProperties: false + required: + - attributes + - file_id + - filename + - score + - text + title: >- + OpenAIResponseOutputMessageFileSearchToolCallResults + description: >- + Search results returned by the file search operation. 
description: >- (Optional) Search results returned by the file search operation additionalProperties: false diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index 10cadf38f..8574104dc 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -170,6 +170,23 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel): type: Literal["web_search_call"] = "web_search_call" +class OpenAIResponseOutputMessageFileSearchToolCallResults(BaseModel): + """Search results returned by the file search operation. + + :param attributes: (Optional) Key-value attributes associated with the file + :param file_id: Unique identifier of the file containing the result + :param filename: Name of the file containing the result + :param score: Relevance score for this search result (between 0 and 1) + :param text: Text content of the search result + """ + + attributes: dict[str, Any] + file_id: str + filename: str + score: float + text: str + + @json_schema_type class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel): """File search tool call output message for OpenAI responses. @@ -185,7 +202,7 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel): queries: list[str] status: str type: Literal["file_search_call"] = "file_search_call" - results: list[dict[str, Any]] | None = None + results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None @json_schema_type diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index db70bc046..b98ca114f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -38,6 +38,7 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseOutputMessageContent, OpenAIResponseOutputMessageContentOutputText, OpenAIResponseOutputMessageFileSearchToolCall, + OpenAIResponseOutputMessageFileSearchToolCallResults, OpenAIResponseOutputMessageFunctionToolCall, OpenAIResponseOutputMessageMCPListTools, OpenAIResponseOutputMessageWebSearchToolCall, @@ -827,12 +828,13 @@ class OpenAIResponsesImpl: text = result.metadata["chunks"][i] if "chunks" in result.metadata else None score = result.metadata["scores"][i] if "scores" in result.metadata else None message.results.append( - { - "file_id": doc_id, - "filename": doc_id, - "text": text, - "score": score, - } + OpenAIResponseOutputMessageFileSearchToolCallResults( + file_id=doc_id, + filename=doc_id, + text=text, + score=score, + attributes={}, + ) ) if error_exc or (result.error_code and result.error_code > 0) or result.error_message: message.status = "failed" From 3d901178914fefcedeccb8467aaed014fa8b637a Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 12 Aug 2025 16:15:53 -0700 Subject: [PATCH 09/45] chore(tests): fix responses and vector_io tests (#3119) Some fixes to MCP tests. And a bunch of fixes for Vector providers. I also enabled a bunch of Vector IO tests to be used with `LlamaStackLibraryClient` ## Test Plan Run Responses tests with llama stack library client: ``` pytest -s -v tests/integration/non_ci/responses/ --stack-config=server:starter \ --text-model openai/gpt-4o \ --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ -k "client_with_models" ``` Do the same with `-k openai_client` The rest should be taken care of by CI. 
--- .github/actions/setup-runner/action.yml | 2 +- .github/workflows/integration-tests.yml | 3 +- .../workflows/integration-vector-io-tests.yml | 4 +- llama_stack/core/build.py | 2 +- llama_stack/core/routers/inference.py | 3 +- llama_stack/core/routing_tables/models.py | 2 + llama_stack/log.py | 1 + .../models/llama/llama3/chat_format.py | 1 + .../agents/meta_reference/openai_responses.py | 4 ++ .../providers/inline/vector_io/faiss/faiss.py | 19 ++++--- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 25 +++++---- llama_stack/providers/registry/vector_io.py | 4 ++ .../remote/inference/fireworks/fireworks.py | 3 + .../remote/vector_io/chroma/chroma.py | 17 ++++-- .../remote/vector_io/milvus/milvus.py | 16 +++--- .../remote/vector_io/pgvector/pgvector.py | 12 ++-- .../remote/vector_io/qdrant/qdrant.py | 18 +++--- .../remote/vector_io/weaviate/weaviate.py | 11 ++-- .../utils/memory/openai_vector_store_mixin.py | 32 ++++++++--- .../providers/utils/memory/vector_store.py | 15 ++++- tests/common/mcp.py | 12 +--- tests/integration/fixtures/common.py | 2 +- .../fixtures/test_cases/responses.yaml | 6 +- .../non_ci/responses/test_responses.py | 17 ++++-- .../vector_io/test_openai_vector_stores.py | 56 ++++++++----------- 25 files changed, 175 insertions(+), 112 deletions(-) diff --git a/.github/actions/setup-runner/action.yml b/.github/actions/setup-runner/action.yml index 0be999fe2..1ca02bbff 100644 --- a/.github/actions/setup-runner/action.yml +++ b/.github/actions/setup-runner/action.yml @@ -28,7 +28,7 @@ runs: # Install llama-stack-client-python based on the client-version input if [ "${{ inputs.client-version }}" = "latest" ]; then echo "Installing latest llama-stack-client-python from main branch" - uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main + uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main elif [ "${{ inputs.client-version }}" = "published" ]; then echo "Installing published llama-stack-client-python from PyPI" uv pip install llama-stack-client diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index a38d4971a..9ef49fba3 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -52,7 +52,8 @@ jobs: run: | # Get test directories dynamically, excluding non-test directories # NOTE: we are excluding post_training since the tests take too long - TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" | + TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d | + sed 's|tests/integration/||' | grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" | sort | jq -R -s -c 'split("\n")[:-1]') echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index aa239572b..f4d28e407 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -164,9 +164,9 @@ jobs: ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }} WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }} run: | - uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ + uv run pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ 
tests/integration/vector_io \ - --embedding-model sentence-transformers/all-MiniLM-L6-v2 + --embedding-model inline::sentence-transformers/all-MiniLM-L6-v2 - name: Check Storage and Memory Available After Tests if: ${{ always() }} diff --git a/llama_stack/core/build.py b/llama_stack/core/build.py index b3e35ecef..4b20588fd 100644 --- a/llama_stack/core/build.py +++ b/llama_stack/core/build.py @@ -91,7 +91,7 @@ def get_provider_dependencies( def print_pip_install_help(config: BuildConfig): - normal_deps, special_deps = get_provider_dependencies(config) + normal_deps, special_deps, _ = get_provider_dependencies(config) cprint( f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}", diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 52581cc9d..6a3f07247 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.telemetry.tracing import get_current_span -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="inference") class InferenceRouter(Inference): @@ -854,4 +854,5 @@ class InferenceRouter(Inference): model=model.identifier, object="chat.completion", ) + logger.debug(f"InferenceRouter.completion_response: {final_response}") await self.store.store_chat_completion(final_response, messages) diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index c76619271..34c431e00 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -63,6 +63,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def get_provider_impl(self, model_id: str) -> Any: model = await lookup_model(self, model_id) + if model.provider_id not in self.impls_by_provider_id: + raise ValueError(f"Provider {model.provider_id} not found in the routing table") return self.impls_by_provider_id[model.provider_id] async def register_model( diff --git a/llama_stack/log.py b/llama_stack/log.py index 0a2d63ef6..7507aface 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -32,6 +32,7 @@ CATEGORIES = [ "tools", "client", "telemetry", + "openai_responses", ] # Initialize category levels with default level diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index 0a973cf0c..1f88a1699 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -236,6 +236,7 @@ class ChatFormat: arguments_json=json.dumps(tool_arguments), ) ) + content = "" return RawMessage( role="assistant", diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index b98ca114f..347954908 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -488,8 +488,12 @@ class OpenAIResponsesImpl: # Convert collected chunks to complete response if chat_response_tool_calls: tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())] + + # when there are tool calls, we need to clear the content + chat_response_content = [] else: tool_calls 
= None + assistant_message = OpenAIAssistantMessageParam( content="".join(chat_response_content), tool_calls=tool_calls, diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 5a063592c..af61da59b 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -33,6 +33,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -128,11 +129,12 @@ class FaissIndex(EmbeddingIndex): # Save updated index await self._save_index() - async def delete_chunk(self, chunk_id: str) -> None: - if chunk_id not in self.chunk_ids: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: + chunk_ids = [c.chunk_id for c in chunks_for_deletion] + if not set(chunk_ids).issubset(self.chunk_ids): return - async with self.chunk_id_lock: + def remove_chunk(chunk_id: str): index = self.chunk_ids.index(chunk_id) self.index.remove_ids(np.array([index])) @@ -146,6 +148,10 @@ class FaissIndex(EmbeddingIndex): self.chunk_by_index = new_chunk_by_index self.chunk_ids.pop(index) + async with self.chunk_id_lock: + for chunk_id in chunk_ids: + remove_chunk(chunk_id) + await self._save_index() async def query_vector( @@ -297,8 +303,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - """Delete a chunk from a faiss index""" + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a faiss index""" faiss_index = self.cache[store_id].index - for chunk_id in chunk_ids: - await faiss_index.delete_chunk(chunk_id) + await faiss_index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 1fff7b484..cc1982f3b 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -31,6 +31,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIV from llama_stack.providers.utils.memory.vector_store import ( RERANKER_TYPE_RRF, RERANKER_TYPE_WEIGHTED, + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -426,34 +427,36 @@ class SQLiteVecIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the SQLite vector store.""" + chunk_ids = [c.chunk_id for c in chunks_for_deletion] - def _delete_chunk(): + def _delete_chunks(): connection = _create_sqlite_connection(self.db_path) cur = connection.cursor() try: cur.execute("BEGIN TRANSACTION") # Delete from metadata table - cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,)) + placeholders = ",".join("?" 
* len(chunk_ids)) + cur.execute(f"DELETE FROM {self.metadata_table} WHERE id IN ({placeholders})", chunk_ids) # Delete from vector table - cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,)) + cur.execute(f"DELETE FROM {self.vector_table} WHERE id IN ({placeholders})", chunk_ids) # Delete from FTS table - cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,)) + cur.execute(f"DELETE FROM {self.fts_table} WHERE id IN ({placeholders})", chunk_ids) connection.commit() except Exception as e: connection.rollback() - logger.error(f"Error deleting chunk {chunk_id}: {e}") + logger.error(f"Error deleting chunks: {e}") raise finally: cur.close() connection.close() - await asyncio.to_thread(_delete_chunk) + await asyncio.to_thread(_delete_chunks) class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): @@ -551,12 +554,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc raise VectorStoreNotFoundError(vector_db_id) return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - """Delete a chunk from a sqlite_vec index.""" + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a sqlite_vec index.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise VectorStoreNotFoundError(store_id) - for chunk_id in chunk_ids: - # Use the index's delete_chunk method - await index.index.delete_chunk(chunk_id) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index ed170b508..70148eb15 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -342,6 +342,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -350,6 +351,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti module="llama_stack.providers.inline.vector_io.chroma", config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig", api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], description=""" [Chroma](https://www.trychroma.com/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database. 
@@ -464,6 +466,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -731,6 +734,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index ca4c7b578..bd86f7238 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -235,6 +235,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv llama_model = self.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): + # TODO: tools are never added to the request, so we need to add them here if media_present or not llama_model: input_dict["messages"] = [ await convert_message_to_openai_dict(m, download=True) for m in request.messages @@ -378,6 +379,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv # Fireworks chat completions OpenAI-compatible API does not support # tool calls properly. llama_model = self.get_llama_model(model_obj.provider_resource_id) + if llama_model: return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion( self, @@ -431,4 +433,5 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv user=user, ) + logger.debug(f"fireworks params: {params}") return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params) diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 26aeaedfb..8f252711b 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -26,6 +26,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -115,8 +116,10 @@ class ChromaIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in Chroma") - async def delete_chunk(self, chunk_id: str) -> None: - raise NotImplementedError("delete_chunk is not supported in Chroma") + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete a single chunk from the Chroma collection by its ID.""" + ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion] + await maybe_await(self.collection.delete(ids=ids)) async def query_hybrid( self, @@ -144,6 +147,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.cache = {} self.kvstore: KVStore | None = None self.vector_db_store = None + self.files_api = files_api async def initialize(self) -> None: self.kvstore = await kvstore_impl(self.config.kvstore) @@ -227,5 +231,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.cache[vector_db_id] = index return index - async def delete_chunks(self, store_id: str, chunk_ids: 
list[str]) -> None: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma") + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a Chroma vector store.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise ValueError(f"Vector DB {store_id} not found") + + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index b09edb65c..0eaae81b3 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -28,6 +28,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( RERANKER_TYPE_WEIGHTED, + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -287,14 +288,17 @@ class MilvusIndex(EmbeddingIndex): return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the Milvus collection.""" + chunk_ids = [c.chunk_id for c in chunks_for_deletion] try: + # Use IN clause with square brackets and single quotes for VARCHAR field + chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids) await asyncio.to_thread( - self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"' + self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]" ) except Exception as e: - logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}") + logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}") raise @@ -420,12 +424,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Delete a chunk from a milvus vector store.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise VectorStoreNotFoundError(store_id) - for chunk_id in chunk_ids: - # Use the index's delete_chunk method - await index.index.delete_chunk(chunk_id) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index b1645ac5a..d2a5d910b 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -163,10 +164,11 @@ class PGVectorIndex(EmbeddingIndex): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") - async def delete_chunk(self, 
chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the PostgreSQL table.""" + chunk_ids = [c.chunk_id for c in chunks_for_deletion] with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,)) + cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,)) class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): @@ -275,12 +277,10 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) return self.cache[vector_db_id] - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Delete a chunk from a PostgreSQL vector store.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise VectorStoreNotFoundError(store_id) - for chunk_id in chunk_ids: - # Use the index's delete_chunk method - await index.index.delete_chunk(chunk_id) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 144da0f4f..018015780 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -29,6 +29,7 @@ from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig a from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -88,15 +89,16 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Remove a chunk from the Qdrant collection.""" + chunk_ids = [convert_id(c.chunk_id) for c in chunks_for_deletion] try: await self.client.delete( collection_name=self.collection_name, - points_selector=models.PointIdsList(points=[convert_id(chunk_id)]), + points_selector=models.PointIdsList(points=chunk_ids), ) except Exception as e: - log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}") + log.error(f"Error deleting chunks from Qdrant collection {self.collection_name}: {e}") raise async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: @@ -264,12 +266,14 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) -> VectorStoreFileObject: # Qdrant doesn't allow multiple clients to access the same storage path simultaneously. 
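# Note: the signature above declares a VectorStoreFileObject return value, but the previous
# body awaited the parent mixin call without returning its result, so callers received None.
# The replacement below returns the awaited result while still holding the adapter-wide lock
# that serializes access to the shared Qdrant storage path.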
async with self._qdrant_lock: - await super().openai_attach_file_to_vector_store(vector_store_id, file_id, attributes, chunking_strategy) + return await super().openai_attach_file_to_vector_store( + vector_store_id, file_id, attributes, chunking_strategy + ) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: """Delete chunks from a Qdrant vector store.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: raise ValueError(f"Vector DB {store_id} not found") - for chunk_id in chunk_ids: - await index.index.delete_chunk(chunk_id) + + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 11da8902c..966724848 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -26,6 +26,7 @@ from llama_stack.providers.utils.memory.openai_vector_store_mixin import ( OpenAIVectorStoreMixin, ) from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, EmbeddingIndex, VectorDBWithIndex, ) @@ -67,6 +68,7 @@ class WeaviateIndex(EmbeddingIndex): data_objects.append( wvc.data.DataObject( properties={ + "chunk_id": chunk.chunk_id, "chunk_content": chunk.model_dump_json(), }, vector=embeddings[i].tolist(), @@ -79,10 +81,11 @@ class WeaviateIndex(EmbeddingIndex): # TODO: make this async friendly collection.data.insert_many(data_objects) - async def delete_chunk(self, chunk_id: str) -> None: + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None: sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) collection = self.client.collections.get(sanitized_collection_name) - collection.data.delete_many(where=Filter.by_property("id").contains_any([chunk_id])) + chunk_ids = [chunk.chunk_id for chunk in chunks_for_deletion] + collection.data.delete_many(where=Filter.by_property("chunk_id").contains_any(chunk_ids)) async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) @@ -307,10 +310,10 @@ class WeaviateVectorIOAdapter( return await index.query_chunks(query, params) - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True) index = await self._get_and_cache_vector_db_index(sanitized_collection_name) if not index: raise ValueError(f"Vector DB {sanitized_collection_name} not found") - await index.delete(chunk_ids) + await index.index.delete_chunks(chunks_for_deletion) diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 7b6e69df1..120d0d4fc 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -6,7 +6,6 @@ import asyncio import json -import logging import mimetypes import time import uuid @@ -37,10 +36,15 @@ from llama_stack.apis.vector_io import ( VectorStoreSearchResponse, VectorStoreSearchResponsePage, ) +from llama_stack.log 
import get_logger from llama_stack.providers.utils.kvstore.api import KVStore -from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks +from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, + content_from_data_and_mime_type, + make_overlapped_chunks, +) -logger = logging.getLogger(__name__) +logger = get_logger(__name__, category="vector_io") # Constants for OpenAI vector stores CHUNK_MULTIPLIER = 5 @@ -154,8 +158,8 @@ class OpenAIVectorStoreMixin(ABC): self.openai_vector_stores = await self._load_openai_vector_stores() @abstractmethod - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - """Delete a chunk from a vector store.""" + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from a vector store.""" pass @abstractmethod @@ -614,7 +618,7 @@ class OpenAIVectorStoreMixin(ABC): ) vector_store_file_object.status = "completed" except Exception as e: - logger.error(f"Error attaching file to vector store: {e}") + logger.exception("Error attaching file to vector store") vector_store_file_object.status = "failed" vector_store_file_object.last_error = VectorStoreFileLastError( code="server_error", @@ -767,7 +771,21 @@ class OpenAIVectorStoreMixin(ABC): dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) chunks = [Chunk.model_validate(c) for c in dict_chunks] - await self.delete_chunks(vector_store_id, [str(c.chunk_id) for c in chunks if c.chunk_id]) + + # Create ChunkForDeletion objects with both chunk_id and document_id + chunks_for_deletion = [] + for c in chunks: + if c.chunk_id: + document_id = c.metadata.get("document_id") or ( + c.chunk_metadata.document_id if c.chunk_metadata else None + ) + if document_id: + chunks_for_deletion.append(ChunkForDeletion(chunk_id=str(c.chunk_id), document_id=document_id)) + else: + logger.warning(f"Chunk {c.chunk_id} has no document_id, skipping deletion") + + if chunks_for_deletion: + await self.delete_chunks(vector_store_id, chunks_for_deletion) store_info = self.openai_vector_stores[vector_store_id].copy() diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index bb9002f30..6ae5bb521 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -16,6 +16,7 @@ from urllib.parse import unquote import httpx import numpy as np from numpy.typing import NDArray +from pydantic import BaseModel from llama_stack.apis.common.content_types import ( URL, @@ -34,6 +35,18 @@ from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id log = logging.getLogger(__name__) + +class ChunkForDeletion(BaseModel): + """Information needed to delete a chunk from a vector store. 
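    Both identifiers are needed because not every backend keys stored chunks by chunk_id alone; the Chroma adapter in this change, for example, addresses entries by a combined "<document_id>:<chunk_id>" id.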
+ + :param chunk_id: The ID of the chunk to delete + :param document_id: The ID of the document this chunk belongs to + """ + + chunk_id: str + document_id: str + + # Constants for reranker types RERANKER_TYPE_RRF = "rrf" RERANKER_TYPE_WEIGHTED = "weighted" @@ -232,7 +245,7 @@ class EmbeddingIndex(ABC): raise NotImplementedError() @abstractmethod - async def delete_chunk(self, chunk_id: str): + async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]): raise NotImplementedError() @abstractmethod diff --git a/tests/common/mcp.py b/tests/common/mcp.py index 775e38295..d05ac39c6 100644 --- a/tests/common/mcp.py +++ b/tests/common/mcp.py @@ -16,13 +16,10 @@ MCP_TOOLGROUP_ID = "mcp::localmcp" def default_tools(): """Default tools for backward compatibility.""" - from mcp import types from mcp.server.fastmcp import Context - async def greet_everyone( - url: str, ctx: Context - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - return [types.TextContent(type="text", text="Hello, world!")] + async def greet_everyone(url: str, ctx: Context) -> str: + return "Hello, world!" async def get_boiling_point(liquid_name: str, celsius: bool = True) -> int: """ @@ -45,7 +42,6 @@ def default_tools(): def dependency_tools(): """Tools with natural dependencies for multi-turn testing.""" - from mcp import types from mcp.server.fastmcp import Context async def get_user_id(username: str, ctx: Context) -> str: @@ -106,7 +102,7 @@ def dependency_tools(): else: access = "no" - return [types.TextContent(type="text", text=access)] + return access async def get_experiment_id(experiment_name: str, ctx: Context) -> str: """ @@ -245,7 +241,6 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal try: yield {"server_url": server_url} finally: - print("Telling SSE server to exit") server_instance.should_exit = True time.sleep(0.5) @@ -269,4 +264,3 @@ def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Cal AppStatus.should_exit = False AppStatus.should_exit_event = None - print("SSE server exited") diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index c91391f19..0b7132d71 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -270,7 +270,7 @@ def openai_client(client_with_models): @pytest.fixture(params=["openai_client", "client_with_models"]) def compat_client(request, client_with_models): - if isinstance(client_with_models, LlamaStackAsLibraryClient): + if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient): # OpenAI client expects a server, so unless we also rewrite OpenAI client's requests # to go via the Stack library client (which itself rewrites requests to be served inline), # we cannot do this. diff --git a/tests/integration/non_ci/responses/fixtures/test_cases/responses.yaml b/tests/integration/non_ci/responses/fixtures/test_cases/responses.yaml index 6db0dd970..353a64291 100644 --- a/tests/integration/non_ci/responses/fixtures/test_cases/responses.yaml +++ b/tests/integration/non_ci/responses/fixtures/test_cases/responses.yaml @@ -137,7 +137,7 @@ test_response_multi_turn_tool_execution: server_url: "" output: "yes" - case_id: "experiment_results_lookup" - input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me what you found." 
+ input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me the boiling point in Celsius." tools: - type: mcp server_label: "localmcp" @@ -149,7 +149,7 @@ test_response_multi_turn_tool_execution_streaming: test_params: case: - case_id: "user_permissions_workflow" - input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step." + input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response." tools: - type: mcp server_label: "localmcp" @@ -157,7 +157,7 @@ test_response_multi_turn_tool_execution_streaming: stream: true output: "no" - case_id: "experiment_analysis_streaming" - input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Please stream your analysis process." + input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Return only one tool call per step. Please stream your analysis process." tools: - type: mcp server_label: "localmcp" diff --git a/tests/integration/non_ci/responses/test_responses.py b/tests/integration/non_ci/responses/test_responses.py index 4f4f27d7f..39d00f328 100644 --- a/tests/integration/non_ci/responses/test_responses.py +++ b/tests/integration/non_ci/responses/test_responses.py @@ -363,6 +363,9 @@ def test_response_non_streaming_file_search_empty_vector_store(request, compat_c ids=case_id_generator, ) def test_response_non_streaming_mcp_tool(request, compat_client, text_model_id, case): + if not isinstance(compat_client, LlamaStackAsLibraryClient): + pytest.skip("in-process MCP server is only supported in library client") + with make_mcp_server() as mcp_server_info: tools = case["tools"] for tool in tools: @@ -485,8 +488,11 @@ def test_response_non_streaming_multi_turn_image(request, compat_client, text_mo responses_test_cases["test_response_multi_turn_tool_execution"]["test_params"]["case"], ids=case_id_generator, ) -def test_response_non_streaming_multi_turn_tool_execution(request, compat_client, text_model_id, case): +def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case): """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence.""" + if not isinstance(compat_client, LlamaStackAsLibraryClient): + pytest.skip("in-process MCP server is only supported in library client") + with make_mcp_server(tools=dependency_tools()) as mcp_server_info: tools = case["tools"] # Replace the placeholder URL with the actual server URL @@ -541,8 +547,11 @@ def test_response_non_streaming_multi_turn_tool_execution(request, compat_client responses_test_cases["test_response_multi_turn_tool_execution_streaming"]["test_params"]["case"], ids=case_id_generator, ) -async def test_response_streaming_multi_turn_tool_execution(request, compat_client, text_model_id, case): +def 
test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id, case): """Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence.""" + if not isinstance(compat_client, LlamaStackAsLibraryClient): + pytest.skip("in-process MCP server is only supported in library client") + with make_mcp_server(tools=dependency_tools()) as mcp_server_info: tools = case["tools"] # Replace the placeholder URL with the actual server URL @@ -634,7 +643,7 @@ async def test_response_streaming_multi_turn_tool_execution(request, compat_clie }, ], ) -def test_response_text_format(request, compat_client, text_model_id, text_format): +def test_response_text_format(compat_client, text_model_id, text_format): if isinstance(compat_client, LlamaStackAsLibraryClient): pytest.skip("Responses API text format is not yet supported in library client.") @@ -653,7 +662,7 @@ def test_response_text_format(request, compat_client, text_model_id, text_format @pytest.fixture -def vector_store_with_filtered_files(request, compat_client, text_model_id, tmp_path_factory): +def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_factory): """Create a vector store with multiple files that have different attributes for filtering tests.""" if isinstance(compat_client, LlamaStackAsLibraryClient): pytest.skip("Responses API file search is not yet supported in library client.") diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 3212a7568..7ccca9077 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -9,10 +9,11 @@ import time from io import BytesIO import pytest -from llama_stack_client import BadRequestError, LlamaStackClient +from llama_stack_client import BadRequestError from openai import BadRequestError as OpenAIBadRequestError from llama_stack.apis.vector_io import Chunk +from llama_stack.core.library_client import LlamaStackAsLibraryClient logger = logging.getLogger(__name__) @@ -475,9 +476,6 @@ def test_openai_vector_store_attach_file(compat_client_with_empty_stores, client """Test OpenAI vector store attach file.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - if isinstance(compat_client_with_empty_stores, LlamaStackClient): - pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient") - compat_client = compat_client_with_empty_stores # Create a vector store @@ -526,9 +524,6 @@ def test_openai_vector_store_attach_files_on_creation(compat_client_with_empty_s """Test OpenAI vector store attach files on creation.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - if isinstance(compat_client_with_empty_stores, LlamaStackClient): - pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient") - compat_client = compat_client_with_empty_stores # Create some files and attach them to the vector store @@ -582,9 +577,6 @@ def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_ """Test OpenAI vector store list files.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - if isinstance(compat_client_with_empty_stores, LlamaStackClient): - pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient") - compat_client = compat_client_with_empty_stores # Create a vector store @@ -597,16 +589,20 @@ def 
test_openai_vector_store_list_files(compat_client_with_empty_stores, client_ file_buffer.name = f"openai_test_{i}.txt" file = compat_client.files.create(file=file_buffer, purpose="assistants") - compat_client.vector_stores.files.create( + response = compat_client.vector_stores.files.create( vector_store_id=vector_store.id, file_id=file.id, ) + assert response is not None + assert response.status == "completed", ( + f"Failed to attach file {file.id} to vector store {vector_store.id}: {response=}" + ) file_ids.append(file.id) files_list = compat_client.vector_stores.files.list(vector_store_id=vector_store.id) assert files_list assert files_list.object == "list" - assert files_list.data + assert files_list.data is not None assert not files_list.has_more assert len(files_list.data) == 3 assert set(file_ids) == {file.id for file in files_list.data} @@ -642,12 +638,13 @@ def test_openai_vector_store_list_files_invalid_vector_store(compat_client_with_ """Test OpenAI vector store list files with invalid vector store ID.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - if isinstance(compat_client_with_empty_stores, LlamaStackClient): - pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient") - compat_client = compat_client_with_empty_stores + if isinstance(compat_client, LlamaStackAsLibraryClient): + errors = ValueError + else: + errors = (BadRequestError, OpenAIBadRequestError) - with pytest.raises((BadRequestError, OpenAIBadRequestError)): + with pytest.raises(errors): compat_client.vector_stores.files.list(vector_store_id="abc123") @@ -655,9 +652,6 @@ def test_openai_vector_store_retrieve_file_contents(compat_client_with_empty_sto """Test OpenAI vector store retrieve file contents.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - if isinstance(compat_client_with_empty_stores, LlamaStackClient): - pytest.skip("Vector Store Files retrieve contents is not yet supported with LlamaStackClient") - compat_client = compat_client_with_empty_stores # Create a vector store @@ -685,9 +679,15 @@ def test_openai_vector_store_retrieve_file_contents(compat_client_with_empty_sto file_id=file.id, ) - assert file_contents - assert file_contents.content[0]["type"] == "text" - assert file_contents.content[0]["text"] == test_content.decode("utf-8") + assert file_contents is not None + assert len(file_contents.content) == 1 + content = file_contents.content[0] + + # llama-stack-client returns a model, openai-python is a badboy and returns a dict + if not isinstance(content, dict): + content = content.model_dump() + assert content["type"] == "text" + assert content["text"] == test_content.decode("utf-8") assert file_contents.filename == file_name assert file_contents.attributes == attributes @@ -696,9 +696,6 @@ def test_openai_vector_store_delete_file(compat_client_with_empty_stores, client """Test OpenAI vector store delete file.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - if isinstance(compat_client_with_empty_stores, LlamaStackClient): - pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient") - compat_client = compat_client_with_empty_stores # Create a vector store @@ -751,9 +748,6 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(compat_client """Test OpenAI vector store delete file removes from vector store.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - if isinstance(compat_client_with_empty_stores, LlamaStackClient): - 
pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient") - compat_client = compat_client_with_empty_stores # Create a vector store @@ -792,9 +786,6 @@ def test_openai_vector_store_update_file(compat_client_with_empty_stores, client """Test OpenAI vector store update file.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - if isinstance(compat_client_with_empty_stores, LlamaStackClient): - pytest.skip("Vector Store Files update is not yet supported with LlamaStackClient") - compat_client = compat_client_with_empty_stores # Create a vector store @@ -840,9 +831,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(compat_client_wit """ skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - if isinstance(compat_client_with_empty_stores, LlamaStackClient): - pytest.skip("Vector Store Files create is not yet supported with LlamaStackClient") - compat_client = compat_client_with_empty_stores # Create a vector store with files From 6358d0a47832185c5484ede73d97fff450db10f5 Mon Sep 17 00:00:00 2001 From: Kelly Brown <86735520+kelbrown20@users.noreply.github.com> Date: Tue, 12 Aug 2025 19:17:03 -0400 Subject: [PATCH 10/45] docs: reorganize contributor guide (#3110) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Description:** Restructures contribution guide and move some sections into categories Screenshot 2025-08-12 at 9 28 44 AM --- CONTRIBUTING.md | 171 ++++++++++++++---------------- docs/source/contributing/index.md | 22 ++-- 2 files changed, 92 insertions(+), 101 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 066fcecf0..c81e9e7b1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,13 +1,82 @@ -# Contributing to Llama-Stack +# Contributing to Llama Stack We want to make contributing to this project as easy and transparent as possible. +## Set up your development environment + +We use [uv](https://github.com/astral-sh/uv) to manage python dependencies and virtual environments. +You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting-started/installation/). + +You can install the dependencies by running: + +```bash +cd llama-stack +uv sync --group dev +uv pip install -e . +source .venv/bin/activate +``` + +```{note} +You can use a specific version of Python with `uv` by adding the `--python ` flag (e.g. `--python 3.12`). +Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`. +For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/). +``` + +Note that you can create a dotenv file `.env` that includes necessary environment variables: +``` +LLAMA_STACK_BASE_URL=http://localhost:8321 +LLAMA_STACK_CLIENT_LOG=debug +LLAMA_STACK_PORT=8321 +LLAMA_STACK_CONFIG= +TAVILY_SEARCH_API_KEY= +BRAVE_SEARCH_API_KEY= +``` + +And then use this dotenv file when running client SDK tests via the following: +```bash +uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py --text-model=meta-llama/Llama-3.1-8B-Instruct +``` + +### Pre-commit Hooks + +We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running: + +```bash +uv run pre-commit install +``` + +After that, pre-commit hooks will run automatically before each commit. 
+ +Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running: + +```bash +uv run pre-commit run --all-files +``` + +```{caution} +Before pushing your changes, make sure that the pre-commit hooks have passed successfully. +``` + ## Discussions -> Issues -> Pull Requests We actively welcome your pull requests. However, please read the following. This is heavily inspired by [Ghostty](https://github.com/ghostty-org/ghostty/blob/main/CONTRIBUTING.md). If in doubt, please open a [discussion](https://github.com/meta-llama/llama-stack/discussions); we can always convert that to an issue later. +### Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](http://facebook.com/whitehat/info) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +### Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + **I'd like to contribute!** If you are new to the project, start by looking at the issues tagged with "good first issue". If you're interested @@ -51,93 +120,15 @@ Please avoid picking up too many issues at once. This helps you stay focused and Please keep pull requests (PRs) small and focused. If you have a large set of changes, consider splitting them into logically grouped, smaller PRs to facilitate review and testing. -> [!TIP] -> As a general guideline: -> - Experienced contributors should try to keep no more than 5 open PRs at a time. -> - New contributors are encouraged to have only one open PR at a time until they’re familiar with the codebase and process. - -## Contributor License Agreement ("CLA") -In order to accept your pull request, we need you to submit a CLA. You only need -to do this once to work on any of Meta's open source projects. - -Complete your CLA here: - -## Issues -We use GitHub issues to track public bugs. Please ensure your description is -clear and has sufficient instructions to be able to reproduce the issue. - -Meta has a [bounty program](http://facebook.com/whitehat/info) for the safe -disclosure of security bugs. In those cases, please go through the process -outlined on that page and do not file a public issue. - - -## Set up your development environment - -We use [uv](https://github.com/astral-sh/uv) to manage python dependencies and virtual environments. -You can install `uv` by following this [guide](https://docs.astral.sh/uv/getting-started/installation/). - -You can install the dependencies by running: - -```bash -cd llama-stack -uv sync --group dev -uv pip install -e . -source .venv/bin/activate +```{tip} +As a general guideline: +- Experienced contributors should try to keep no more than 5 open PRs at a time. +- New contributors are encouraged to have only one open PR at a time until they’re familiar with the codebase and process. ``` -> [!NOTE] -> You can use a specific version of Python with `uv` by adding the `--python ` flag (e.g. `--python 3.12`) -> Otherwise, `uv` will automatically select a Python version according to the `requires-python` section of the `pyproject.toml`. -> For more info, see the [uv docs around Python versions](https://docs.astral.sh/uv/concepts/python-versions/). 
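For example, pinning the interpreter while syncing the environment might look like the following (3.12 is only the illustrative version from the note above):

```bash
# any interpreter version satisfying the project's requires-python works here; 3.12 is illustrative
uv sync --group dev --python 3.12
```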
+## Repository guidelines -Note that you can create a dotenv file `.env` that includes necessary environment variables: -``` -LLAMA_STACK_BASE_URL=http://localhost:8321 -LLAMA_STACK_CLIENT_LOG=debug -LLAMA_STACK_PORT=8321 -LLAMA_STACK_CONFIG= -TAVILY_SEARCH_API_KEY= -BRAVE_SEARCH_API_KEY= -``` - -And then use this dotenv file when running client SDK tests via the following: -```bash -uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py --text-model=meta-llama/Llama-3.1-8B-Instruct -``` - -## Pre-commit Hooks - -We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running: - -```bash -uv run pre-commit install -``` - -After that, pre-commit hooks will run automatically before each commit. - -Alternatively, if you don't want to install the pre-commit hooks, you can run the checks manually by running: - -```bash -uv run pre-commit run --all-files -``` - -> [!CAUTION] -> Before pushing your changes, make sure that the pre-commit hooks have passed successfully. - -## Running tests - -You can find the Llama Stack testing documentation [here](https://github.com/meta-llama/llama-stack/blob/main/tests/README.md). - -## Adding a new dependency to the project - -To add a new dependency to the project, you can use the `uv` command. For example, to add `foo` to the project, you can run: - -```bash -uv add foo -uv sync -``` - -## Coding Style +### Coding Style * Comments should provide meaningful insights into the code. Avoid filler comments that simply describe the next step, as they create unnecessary clutter, same goes for docstrings. @@ -159,6 +150,10 @@ uv sync * When possible, use keyword arguments only when calling functions. * Llama Stack utilizes [custom Exception classes](llama_stack/apis/common/errors.py) for certain Resources that should be used where applicable. +### License +By contributing to Llama, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. + ## Common Tasks Some tips about common tasks you work on while contributing to Llama Stack: @@ -210,8 +205,4 @@ If you modify or add new API endpoints, update the API documentation accordingly uv run ./docs/openapi_generator/run_openapi_generator.sh ``` -The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing. - -## License -By contributing to Llama, you agree that your contributions will be licensed -under the LICENSE file in the root directory of this source tree. +The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing. \ No newline at end of file diff --git a/docs/source/contributing/index.md b/docs/source/contributing/index.md index 79c3861ea..7a3a1c2e2 100644 --- a/docs/source/contributing/index.md +++ b/docs/source/contributing/index.md @@ -2,17 +2,6 @@ ```{include} ../../../CONTRIBUTING.md ``` -## Testing - -See the [Test Page](testing.md) which describes how to test your changes. -```{toctree} -:maxdepth: 1 -:hidden: -:caption: Testing - -testing -``` - ## Adding a New Provider See the [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack. @@ -27,3 +16,14 @@ See the [External Provider Page](../providers/external/index.md) which describes new_api_provider new_vector_database ``` + +## Testing + +See the [Test Page](testing.md) which describes how to test your changes. 
+```{toctree} +:maxdepth: 1 +:hidden: +:caption: Testing + +testing +``` \ No newline at end of file From fffdab4f5c23b4ac045b898bdb1e69edda3a3498 Mon Sep 17 00:00:00 2001 From: Chacksu Date: Wed, 13 Aug 2025 09:18:25 -0400 Subject: [PATCH 11/45] fix: Dell distribution missing kvstore (#3113) # What does this PR do? - Added kvstore config to ChromaDB provider config for Dell distribution similar to [starter config](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/distributions/starter/run.yaml#L110-L112) - Fixed [error](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/inference/_generated/_async_client.py#L3424-L3425) getting endpoint information by adding `hf-inference` as the provider to the `AsyncInferenceClient` (TGI client). ## Test Plan ``` export INFERENCE_PORT=8181 export DEH_URL=http://0.0.0.0:$INFERENCE_PORT export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct export CHROMADB_HOST=localhost export CHROMADB_PORT=8000 export CHROMA_URL=http://$CHROMADB_HOST:$CHROMADB_PORT export CUDA_VISIBLE_DEVICES=0 export LLAMA_STACK_PORT=8321 export HF_TOKEN=[redacted] # TGI Server docker run --rm -it \ --pull always \ --network host \ -v $HOME/.cache/huggingface:/data \ -e HF_TOKEN=$HF_TOKEN \ -e PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ -p $INFERENCE_PORT:$INFERENCE_PORT \ --gpus all \ ghcr.io/huggingface/text-generation-inference:latest \ --dtype float16 \ --usage-stats off \ --sharded false \ --cuda-memory-fraction 0.8 \ --model-id meta-llama/Llama-3.2-3B-Instruct \ --port $INFERENCE_PORT \ --hostname 0.0.0.0 # Chrome DB docker run --rm -it \ --name chromadb \ --net=host -p 8000:8000 \ -v ~/chroma:/chroma/chroma \ -e IS_PERSISTENT=TRUE \ -e ANONYMIZED_TELEMETRY=FALSE \ chromadb/chroma:latest # Llama Stack llama stack run dell \ --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ --env DEH_URL=$DEH_URL \ --env CHROMA_URL=$CHROMA_URL ``` --------- Co-authored-by: Connor Hack Co-authored-by: Ashwin Bharambe --- llama_stack/distributions/dell/dell.py | 8 +++++--- llama_stack/distributions/dell/run-with-safety.yaml | 5 ++++- llama_stack/distributions/dell/run.yaml | 5 ++++- llama_stack/providers/remote/inference/tgi/tgi.py | 4 +--- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/llama_stack/distributions/dell/dell.py b/llama_stack/distributions/dell/dell.py index b561ea00e..e3bf0ee03 100644 --- a/llama_stack/distributions/dell/dell.py +++ b/llama_stack/distributions/dell/dell.py @@ -16,6 +16,7 @@ from llama_stack.distributions.template import DistributionTemplate, RunConfigSe from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) +from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig def get_distribution_template() -> DistributionTemplate: @@ -71,9 +72,10 @@ def get_distribution_template() -> DistributionTemplate: chromadb_provider = Provider( provider_id="chromadb", provider_type="remote::chromadb", - config={ - "url": "${env.CHROMA_URL}", - }, + config=ChromaVectorIOConfig.sample_run_config( + f"~/.llama/distributions/{name}/", + url="${env.CHROMADB_URL:=}", + ), ) inference_model = ModelInput( diff --git a/llama_stack/distributions/dell/run-with-safety.yaml b/llama_stack/distributions/dell/run-with-safety.yaml index ecc6729eb..d89c92aa1 100644 --- a/llama_stack/distributions/dell/run-with-safety.yaml +++ b/llama_stack/distributions/dell/run-with-safety.yaml @@ -26,7 +26,10 @@ providers: - provider_id: chromadb provider_type: 
remote::chromadb config: - url: ${env.CHROMA_URL} + url: ${env.CHROMADB_URL:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/distributions/dell/run.yaml b/llama_stack/distributions/dell/run.yaml index fc2553526..7397410ba 100644 --- a/llama_stack/distributions/dell/run.yaml +++ b/llama_stack/distributions/dell/run.yaml @@ -22,7 +22,10 @@ providers: - provider_id: chromadb provider_type: remote::chromadb config: - url: ${env.CHROMA_URL} + url: ${env.CHROMADB_URL:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell/}/chroma_remote_registry.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index a5bb079ef..323831845 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -308,9 +308,7 @@ class TGIAdapter(_HfAdapter): if not config.url: raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.") log.info(f"Initializing TGI client with url={config.url}") - self.client = AsyncInferenceClient( - model=config.url, - ) + self.client = AsyncInferenceClient(model=config.url, provider="hf-inference") endpoint_info = await self.client.get_endpoint_info() self.max_tokens = endpoint_info["max_total_tokens"] self.model_id = endpoint_info["model_id"] From 5bd6cb52fb9dd5a1a4defaeef0fe881cce59efdd Mon Sep 17 00:00:00 2001 From: Krzysztof Malczuk <2000krzysztof@gmail.com> Date: Wed, 13 Aug 2025 15:14:03 +0100 Subject: [PATCH 12/45] fix: github action canceling valid tasks for checking semantic pr title (#3127) # What does this PR do? This PR changes the group name from github.ref to github.even.pull_request_number. The reason for this is that github.ref does not act as a unique identifier in the pull_request_target event and only is unique in pull_request. The github action was getting canceled was because the group name was not unique in the concurrency section. Closes #3102 ## Test Plan To test this I have created a fake github action and ran it trough act to see what the github.ref variable produced and what alternatives can be used. This confirmed that the github.ref was not unique and that github.event.pull_request_number is unique to the PR. --- .github/workflows/semantic-pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/semantic-pr.yml b/.github/workflows/semantic-pr.yml index 4df7324c4..57a4df646 100644 --- a/.github/workflows/semantic-pr.yml +++ b/.github/workflows/semantic-pr.yml @@ -11,7 +11,7 @@ on: - synchronize concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} cancel-in-progress: true permissions: From 92aca434a72cc0170c275e857781ebbce1e9e709 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Wed, 13 Aug 2025 08:46:26 -0600 Subject: [PATCH 13/45] fix: Fix list_sessions() (#3114) # What does this PR do? 1. Updates `AgentPersistence.list_sessions()` to properly filter out `Turn` keys from `Session` keys. 2. 
Adds a suite of unit tests to confirm the `list_sessions()` behavior and tests the failed sample in https://github.com/meta-llama/llama-stack/issues/3048 ## Fixes https://github.com/meta-llama/llama-stack/issues/3048 ## Test Plan Unit tests added. --------- Signed-off-by: Francisco Javier Arceo --- .../agents/meta_reference/persistence.py | 6 +- .../agent/test_agent_meta_reference.py | 347 ++++++++++++++++++ 2 files changed, 352 insertions(+), 1 deletion(-) create mode 100644 tests/unit/providers/agent/test_agent_meta_reference.py diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 7a8d99b78..0b234d96c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -191,7 +191,11 @@ class AgentPersistence: sessions = [] for value in values: try: - session_info = Session(**json.loads(value)) + data = json.loads(value) + if "turn_id" in data: + continue + + session_info = Session(**data) sessions.append(session_info) except Exception as e: log.error(f"Error parsing session info: {e}") diff --git a/tests/unit/providers/agent/test_agent_meta_reference.py b/tests/unit/providers/agent/test_agent_meta_reference.py new file mode 100644 index 000000000..3fc60024a --- /dev/null +++ b/tests/unit/providers/agent/test_agent_meta_reference.py @@ -0,0 +1,347 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from datetime import UTC, datetime +from unittest.mock import AsyncMock, patch + +import pytest + +from llama_stack.apis.agents import Session +from llama_stack.core.datatypes import User +from llama_stack.providers.inline.agents.meta_reference.persistence import ( + AgentPersistence, + AgentSessionInfo, +) +from llama_stack.providers.utils.kvstore import KVStore + + +@pytest.fixture +def mock_kvstore(): + return AsyncMock(spec=KVStore) + + +@pytest.fixture +def mock_policy(): + return [] + + +@pytest.fixture +def agent_persistence(mock_kvstore, mock_policy): + return AgentPersistence(agent_id="test-agent-123", kvstore=mock_kvstore, policy=mock_policy) + + +@pytest.fixture +def sample_session(): + return AgentSessionInfo( + session_id="session-123", + session_name="Test Session", + started_at=datetime.now(UTC), + owner=User(principal="user-123", attributes=None), + turns=[], + identifier="test-session", + type="session", + ) + + +@pytest.fixture +def sample_session_json(sample_session): + return sample_session.model_dump_json() + + +class TestAgentPersistenceListSessions: + def setup_mock_kvstore(self, mock_kvstore, session_keys=None, turn_keys=None, invalid_keys=None, custom_data=None): + """Helper to setup mock kvstore with sessions, turns, and custom/invalid data + + Args: + mock_kvstore: The mock KVStore object + session_keys: List of session keys or dict mapping keys to custom session data + turn_keys: List of turn keys or dict mapping keys to custom turn data + invalid_keys: Dict mapping keys to invalid/corrupt data + custom_data: Additional custom data to add to the mock responses + """ + all_keys = [] + mock_data = {} + + # session keys + if session_keys: + if isinstance(session_keys, dict): + all_keys.extend(session_keys.keys()) + mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in session_keys.items()}) + 
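            # session_keys given as a plain list fall through to the branch below,
            # which synthesizes a minimal Session-shaped payload for each key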
else: + all_keys.extend(session_keys) + for key in session_keys: + session_id = key.split(":")[-1] + mock_data[key] = json.dumps( + { + "session_id": session_id, + "session_name": f"Session {session_id}", + "started_at": datetime.now(UTC).isoformat(), + "turns": [], + } + ) + + # turn keys + if turn_keys: + if isinstance(turn_keys, dict): + all_keys.extend(turn_keys.keys()) + mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in turn_keys.items()}) + else: + all_keys.extend(turn_keys) + for key in turn_keys: + parts = key.split(":") + session_id = parts[-2] + turn_id = parts[-1] + mock_data[key] = json.dumps( + { + "turn_id": turn_id, + "session_id": session_id, + "input_messages": [], + "started_at": datetime.now(UTC).isoformat(), + } + ) + + if invalid_keys: + all_keys.extend(invalid_keys.keys()) + mock_data.update(invalid_keys) + + if custom_data: + mock_data.update(custom_data) + + values_list = list(mock_data.values()) + mock_kvstore.values_in_range.return_value = values_list + + async def mock_get(key): + return mock_data.get(key) + + mock_kvstore.get.side_effect = mock_get + + return mock_data + + @pytest.mark.parametrize( + "scenario", + [ + { + # from this issue: https://github.com/meta-llama/llama-stack/issues/3048 + "name": "reported_bug", + "session_keys": ["session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"], + "turn_keys": [ + "session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad" + ], + "expected_sessions": ["1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"], + }, + { + "name": "basic_filtering", + "session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"], + "turn_keys": ["session:test-agent-123:session-1:turn-1", "session:test-agent-123:session-1:turn-2"], + "expected_sessions": ["session-1", "session-2"], + }, + { + "name": "multiple_turns_per_session", + "session_keys": ["session:test-agent-123:session-456"], + "turn_keys": [ + "session:test-agent-123:session-456:turn-789", + "session:test-agent-123:session-456:turn-790", + ], + "expected_sessions": ["session-456"], + }, + { + "name": "multiple_sessions_with_turns", + "session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"], + "turn_keys": [ + "session:test-agent-123:session-1:turn-1", + "session:test-agent-123:session-1:turn-2", + "session:test-agent-123:session-2:turn-3", + ], + "expected_sessions": ["session-1", "session-2"], + }, + ], + ) + async def test_list_sessions_key_filtering(self, agent_persistence, mock_kvstore, scenario): + self.setup_mock_kvstore(mock_kvstore, session_keys=scenario["session_keys"], turn_keys=scenario["turn_keys"]) + + with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log: + result = await agent_persistence.list_sessions() + + assert len(result) == len(scenario["expected_sessions"]) + session_ids = {s.session_id for s in result} + for expected_id in scenario["expected_sessions"]: + assert expected_id in session_ids + + # no errors should be logged + mock_log.error.assert_not_called() + + @pytest.mark.parametrize( + "error_scenario", + [ + { + "name": "invalid_json", + "valid_keys": ["session:test-agent-123:valid-session"], + "invalid_data": {"session:test-agent-123:invalid-json": "corrupted-json-data{"}, + "expected_valid_sessions": ["valid-session"], + "expected_error_count": 1, + }, + { + "name": "missing_fields", + "valid_keys": ["session:test-agent-123:valid-session"], + "invalid_data": { + 
"session:test-agent-123:invalid-schema": json.dumps( + { + "session_id": "invalid-schema", + "session_name": "Missing Fields", + # missing `started_at` and `turns` + } + ) + }, + "expected_valid_sessions": ["valid-session"], + "expected_error_count": 1, + }, + { + "name": "multiple_invalid", + "valid_keys": ["session:test-agent-123:valid-session-1", "session:test-agent-123:valid-session-2"], + "invalid_data": { + "session:test-agent-123:corrupted-json": "not-valid-json{", + "session:test-agent-123:incomplete-data": json.dumps({"incomplete": "data"}), + }, + "expected_valid_sessions": ["valid-session-1", "valid-session-2"], + "expected_error_count": 2, + }, + ], + ) + async def test_list_sessions_error_handling(self, agent_persistence, mock_kvstore, error_scenario): + session_keys = {} + for key in error_scenario["valid_keys"]: + session_id = key.split(":")[-1] + session_keys[key] = { + "session_id": session_id, + "session_name": f"Valid {session_id}", + "started_at": datetime.now(UTC).isoformat(), + "turns": [], + } + + self.setup_mock_kvstore(mock_kvstore, session_keys=session_keys, invalid_keys=error_scenario["invalid_data"]) + + with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log: + result = await agent_persistence.list_sessions() + + # only valid sessions should be returned + assert len(result) == len(error_scenario["expected_valid_sessions"]) + session_ids = {s.session_id for s in result} + for expected_id in error_scenario["expected_valid_sessions"]: + assert expected_id in session_ids + + # error should be logged + assert mock_log.error.call_count > 0 + assert mock_log.error.call_count == error_scenario["expected_error_count"] + + async def test_list_sessions_empty(self, agent_persistence, mock_kvstore): + mock_kvstore.values_in_range.return_value = [] + + result = await agent_persistence.list_sessions() + + assert result == [] + mock_kvstore.values_in_range.assert_called_once_with( + start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff" + ) + + async def test_list_sessions_properties(self, agent_persistence, mock_kvstore): + session_data = { + "session_id": "session-123", + "session_name": "Test Session", + "started_at": datetime.now(UTC).isoformat(), + "owner": {"principal": "user-123", "attributes": None}, + "turns": [], + } + + self.setup_mock_kvstore(mock_kvstore, session_keys={"session:test-agent-123:session-123": session_data}) + + result = await agent_persistence.list_sessions() + + assert len(result) == 1 + assert isinstance(result[0], Session) + assert result[0].session_id == "session-123" + assert result[0].session_name == "Test Session" + assert result[0].turns == [] + assert hasattr(result[0], "started_at") + + async def test_list_sessions_kvstore_exception(self, agent_persistence, mock_kvstore): + mock_kvstore.values_in_range.side_effect = Exception("KVStore error") + + with pytest.raises(Exception, match="KVStore error"): + await agent_persistence.list_sessions() + + async def test_bug_data_loss_with_real_data(self, agent_persistence, mock_kvstore): + # tests the handling of the issue reported in: https://github.com/meta-llama/llama-stack/issues/3048 + session_data = { + "session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d", + "session_name": "Test Session", + "started_at": datetime.now(UTC).isoformat(), + "turns": [], + } + + turn_data = { + "turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad", + "session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d", + "input_messages": [ + {"role": "user", 
"content": "if i had a cluster i would want to call it persistence01", "context": None} + ], + "steps": [ + { + "turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad", + "step_id": "c0f797dd-3d34-4bc5-a8f4-db6af9455132", + "started_at": "2025-08-05T14:31:50.000484Z", + "completed_at": "2025-08-05T14:31:51.303691Z", + "step_type": "inference", + "model_response": { + "role": "assistant", + "content": "OK, I can create a cluster named 'persistence01' for you.", + "stop_reason": "end_of_turn", + "tool_calls": [], + }, + } + ], + "output_message": { + "role": "assistant", + "content": "OK, I can create a cluster named 'persistence01' for you.", + "stop_reason": "end_of_turn", + "tool_calls": [], + }, + "output_attachments": [], + "started_at": "2025-08-05T14:31:49.999950Z", + "completed_at": "2025-08-05T14:31:51.305384Z", + } + + mock_data = { + "session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d": json.dumps(session_data), + "session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad": json.dumps( + turn_data + ), + } + + mock_kvstore.values_in_range.return_value = list(mock_data.values()) + + async def mock_get(key): + return mock_data.get(key) + + mock_kvstore.get.side_effect = mock_get + + with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log: + result = await agent_persistence.list_sessions() + + assert len(result) == 1 + assert result[0].session_id == "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d" + + # confirm no errors logged + mock_log.error.assert_not_called() + + async def test_list_sessions_key_range_construction(self, agent_persistence, mock_kvstore): + mock_kvstore.values_in_range.return_value = [] + + await agent_persistence.list_sessions() + + mock_kvstore.values_in_range.assert_called_once_with( + start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff" + ) From c9b78602d3a99832dd85197dfb206f1439490945 Mon Sep 17 00:00:00 2001 From: IAN MILLER <75687988+r3v5@users.noreply.github.com> Date: Wed, 13 Aug 2025 15:56:26 +0100 Subject: [PATCH 14/45] refactor: modify DELETE API endpoints by returning HTTP 204 No Content + empty body instead of 200 OK + response body with null (#3112) # What does this PR do? The purpose of this PR is to make the behavior DELETE API endpoints be consistent with standard RESTful conventions and eliminate confusion for API consumers. Old Behavior ``` HTTP Status: 200 OK Response Body: null ``` Eg. `curl -X DELETE http://localhost:8321/v1/shields/test-shield` `null% ` `INFO 2025-08-12 16:11:57,932 console_span_processor:65 telemetry: 15:11:57.929 [INFO] ::1:59805 - "DELETE /v1/shields/test-shield HTTP/1.1" 200 ` Updated Behavior ``` HTTP Status: 204 No Content Response Body: empty (no body) ``` Eg. 
`curl -X DELETE http://localhost:8321/v1/shields/test-shield` `INFO 2025-08-12 16:18:16,645 console_span_processor:62 telemetry: 15:18:16.637 [INFO] ::1:60283 - "DELETE /v1/shields/test-shield HTTP/1.1" 204 ` Closes #3090 ## Test Plan Run `./scripts/unit-tests.sh` --- llama_stack/core/library_client.py | 11 ++++++++++- llama_stack/core/server/server.py | 7 ++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index 5fbbf1aff..a93fe509e 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -380,8 +380,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): json_content = json.dumps(convert_pydantic_to_json_value(result)) filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)} + + status_code = httpx.codes.OK + + if options.method.upper() == "DELETE" and result is None: + status_code = httpx.codes.NO_CONTENT + + if status_code == httpx.codes.NO_CONTENT: + json_content = "" + mock_response = httpx.Response( - status_code=httpx.codes.OK, + status_code=status_code, content=json_content.encode("utf-8"), headers={ "Content-Type": "application/json", diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index fe5cc68d7..61ad3e7b3 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -21,10 +21,11 @@ from importlib.metadata import version as parse_version from pathlib import Path from typing import Annotated, Any, get_origin +import httpx import rich.pretty import yaml from aiohttp import hdrs -from fastapi import Body, FastAPI, HTTPException, Request +from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse @@ -236,6 +237,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: result = await maybe_await(value) if isinstance(result, PaginatedResponse) and result.url is None: result.url = route + + if method.upper() == "DELETE" and result is None: + return Response(status_code=httpx.codes.NO_CONTENT) + return result except Exception as e: if logger.isEnabledFor(logging.DEBUG): From 0cbd93c5cc44b790c5b08a2f827944c9ac3223d7 Mon Sep 17 00:00:00 2001 From: Kelly Brown <86735520+kelbrown20@users.noreply.github.com> Date: Wed, 13 Aug 2025 11:06:31 -0400 Subject: [PATCH 15/45] docs: Update blocks formatting in docs/source files (#3120) **Description:** The standard markdown [!NOTE] format is not supported on Sphinx generated documentation, replacing those instances. 
Also updating other Notes, Tips and Warning blocks throughout the source docs WIP: Working to update the provider code gen --- docs/source/building_applications/responses_vs_agents.md | 4 +++- docs/source/building_applications/tools.md | 4 +++- docs/source/providers/vector_io/inline_meta-reference.md | 4 +++- docs/source/providers/vector_io/inline_sqlite_vec.md | 4 +++- docs/source/providers/vector_io/remote_milvus.md | 5 ++++- .../source/references/llama_cli_reference/download_models.md | 4 +++- docs/source/references/llama_cli_reference/index.md | 4 +++- scripts/provider_codegen.py | 4 ++-- 8 files changed, 24 insertions(+), 9 deletions(-) diff --git a/docs/source/building_applications/responses_vs_agents.md b/docs/source/building_applications/responses_vs_agents.md index 3eebfb460..5abe951d6 100644 --- a/docs/source/building_applications/responses_vs_agents.md +++ b/docs/source/building_applications/responses_vs_agents.md @@ -2,7 +2,9 @@ Llama Stack (LLS) provides two different APIs for building AI applications with tool calling capabilities: the **Agents API** and the **OpenAI Responses API**. While both enable AI systems to use tools, and maintain full conversation history, they serve different use cases and have distinct characteristics. -> **Note:** For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API. +```{note} +For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to Agents or Responses API. +``` ## Overview diff --git a/docs/source/building_applications/tools.md b/docs/source/building_applications/tools.md index b19be888c..8a54290ed 100644 --- a/docs/source/building_applications/tools.md +++ b/docs/source/building_applications/tools.md @@ -76,7 +76,9 @@ Features: - Context retrieval with token limits -> **Note:** By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers. +```{note} +By default, llama stack run.yaml defines toolgroups for web search, wolfram alpha and rag, that are provided by tavily-search, wolfram-alpha and rag providers. +``` ## Model Context Protocol (MCP) diff --git a/docs/source/providers/vector_io/inline_meta-reference.md b/docs/source/providers/vector_io/inline_meta-reference.md index 0aac445bd..6f269c441 100644 --- a/docs/source/providers/vector_io/inline_meta-reference.md +++ b/docs/source/providers/vector_io/inline_meta-reference.md @@ -21,5 +21,7 @@ kvstore: ## Deprecation Notice -⚠️ **Warning**: Please use the `inline::faiss` provider instead. +```{warning} +Please use the `inline::faiss` provider instead. +``` diff --git a/docs/source/providers/vector_io/inline_sqlite_vec.md b/docs/source/providers/vector_io/inline_sqlite_vec.md index 7ad8eb252..9e5654a50 100644 --- a/docs/source/providers/vector_io/inline_sqlite_vec.md +++ b/docs/source/providers/vector_io/inline_sqlite_vec.md @@ -25,5 +25,7 @@ kvstore: ## Deprecation Notice -⚠️ **Warning**: Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead. +```{warning} +Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead. 
+``` diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md index 2af64b8bb..075423d04 100644 --- a/docs/source/providers/vector_io/remote_milvus.md +++ b/docs/source/providers/vector_io/remote_milvus.md @@ -204,7 +204,10 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend | | `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. | -> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider. +```{note} + This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider. + ``` + ## Sample Configuration diff --git a/docs/source/references/llama_cli_reference/download_models.md b/docs/source/references/llama_cli_reference/download_models.md index e32099023..a9af65349 100644 --- a/docs/source/references/llama_cli_reference/download_models.md +++ b/docs/source/references/llama_cli_reference/download_models.md @@ -128,7 +128,9 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern **Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). -> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. +```{tip} +Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. +``` ## List the downloaded models diff --git a/docs/source/references/llama_cli_reference/index.md b/docs/source/references/llama_cli_reference/index.md index 4ef76fe7d..09a8b7177 100644 --- a/docs/source/references/llama_cli_reference/index.md +++ b/docs/source/references/llama_cli_reference/index.md @@ -152,7 +152,9 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern **Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). -> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. 
+```{tip} +Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. +``` ## List the downloaded models diff --git a/scripts/provider_codegen.py b/scripts/provider_codegen.py index 84c45fe27..717677c52 100755 --- a/scripts/provider_codegen.py +++ b/scripts/provider_codegen.py @@ -187,7 +187,7 @@ def generate_provider_docs(provider_spec: Any, api_name: str) -> str: if config_info.get("accepts_extra_config"): md_lines.append( - "> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider." + "```{note}\n This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.\n ```\n" ) md_lines.append("") @@ -232,7 +232,7 @@ def generate_provider_docs(provider_spec: Any, api_name: str) -> str: if hasattr(provider_spec, "deprecation_warning") and provider_spec.deprecation_warning: md_lines.append("## Deprecation Notice") md_lines.append("") - md_lines.append(f"⚠️ **Warning**: {provider_spec.deprecation_warning}") + md_lines.append(f"```{{warning}}\n{provider_spec.deprecation_warning}\n```") md_lines.append("") if hasattr(provider_spec, "deprecation_error") and provider_spec.deprecation_error: From 0950168f26ec9489eb7ff6623d270ee903a7a07b Mon Sep 17 00:00:00 2001 From: IAN MILLER <75687988+r3v5@users.noreply.github.com> Date: Wed, 13 Aug 2025 16:43:41 +0100 Subject: [PATCH 16/45] refactor: replace hardcoded status codes by httpx.codes (#3131) # What does this PR do? The purpose of this PR is to eliminate hardcoded status codes in server's responses and replace it by `httpx.codes` functionality for better consistency across the whole project and improvement in code readability. 
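For reference, a minimal sketch (not part of the patch) of why the swap is behavior-preserving: `httpx.codes` is an `IntEnum`, so its members compare equal to the raw integers they replace.

```python
import httpx

# httpx.codes members are IntEnum values, so they are drop-in replacements
# for the previously hardcoded integers (illustrative check only).
assert httpx.codes.BAD_REQUEST == 400
assert httpx.codes.UNAUTHORIZED == 401
assert httpx.codes.FORBIDDEN == 403
assert httpx.codes.UPGRADE_REQUIRED == 426
assert httpx.codes.INTERNAL_SERVER_ERROR == 500
assert httpx.codes.NOT_IMPLEMENTED == 501
assert httpx.codes.GATEWAY_TIMEOUT == 504
```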
## Test Plan Run `./scripts/unit-tests.sh` --- llama_stack/core/server/server.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index 61ad3e7b3..d58037f82 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -116,7 +116,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro if isinstance(exc, RequestValidationError): return HTTPException( - status_code=400, + status_code=httpx.codes.BAD_REQUEST, detail={ "errors": [ { @@ -129,20 +129,20 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro }, ) elif isinstance(exc, ValueError): - return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}") + return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") elif isinstance(exc, BadRequestError): - return HTTPException(status_code=400, detail=str(exc)) + return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc)) elif isinstance(exc, PermissionError | AccessDeniedError): - return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}") + return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}") elif isinstance(exc, asyncio.TimeoutError | TimeoutError): - return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}") + return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}") elif isinstance(exc, NotImplementedError): - return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}") + return HTTPException(status_code=httpx.codes.NOT_IMPLEMENTED, detail=f"Not implemented: {str(exc)}") elif isinstance(exc, AuthenticationRequiredError): - return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}") + return HTTPException(status_code=httpx.codes.UNAUTHORIZED, detail=f"Authentication required: {str(exc)}") else: return HTTPException( - status_code=500, + status_code=httpx.codes.INTERNAL_SERVER_ERROR, detail="Internal server error: An unexpected error occurred.", ) @@ -357,7 +357,7 @@ class ClientVersionMiddleware: await send( { "type": "http.response.start", - "status": 426, + "status": httpx.codes.UPGRADE_REQUIRED, "headers": [[b"content-type", b"application/json"]], } ) From a9081d87b903830b907bea66bcda6b340bf8e212 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 13 Aug 2025 09:34:56 -0700 Subject: [PATCH 17/45] feat(ci): update Recording workflow trigger and concurrency group --- .github/workflows/integration-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 9ef49fba3..f330d2c45 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -5,7 +5,7 @@ run-name: Run the integration test suite from tests/integration in replay mode on: push: branches: [ main ] - pull_request: + pull_request_target: branches: [ main ] types: [opened, synchronize, reopened] paths: @@ -34,7 +34,7 @@ on: concurrency: # Skip concurrency for pushes to main - each commit should be tested independently - group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number }} cancel-in-progress: 
true jobs: From 25e0553eed5a21773369dfd3d81316db2da39629 Mon Sep 17 00:00:00 2001 From: slekkala1 Date: Wed, 13 Aug 2025 09:47:35 -0700 Subject: [PATCH 18/45] chore: Change moderations api response to Provider returned categories (#3098) # What does this PR do? To be compliant with model policies for LLAMA, just return the categories as is from provider, we will lose the OAI compat in moderations api response. ## Test Plan `SAFETY_MODEL=llama-guard3:8b LLAMA_STACK_CONFIG=starter uv run pytest -v tests/integration/safety/test_safety.py --text-model=llama3.2:3b-instruct-fp16 --embedding-model=all-MiniLM-L6-v2 --safety-shield=ollama` --- docs/_static/llama-stack-spec.html | 2 +- docs/_static/llama-stack-spec.yaml | 4 -- llama_stack/apis/safety/safety.py | 37 +------------- llama_stack/core/routers/safety.py | 17 +------ .../inline/safety/llama_guard/llama_guard.py | 48 ++++--------------- .../safety/prompt_guard/prompt_guard.py | 5 +- 6 files changed, 16 insertions(+), 97 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index e2c53d4b0..25f916d87 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -16569,7 +16569,7 @@ "additionalProperties": { "type": "number" }, - "description": "A list of the categories along with their scores as predicted by model. Required set of categories that need to be in response - violence - violence/graphic - harassment - harassment/threatening - hate - hate/threatening - illicit - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent - self-harm/instructions" + "description": "A list of the categories along with their scores as predicted by model." }, "user_message": { "type": "string" diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 85cec3a78..43e9fa95a 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -12322,10 +12322,6 @@ components: type: number description: >- A list of the categories along with their scores as predicted by model. - Required set of categories that need to be in response - violence - violence/graphic - - harassment - harassment/threatening - hate - hate/threatening - illicit - - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent - - self-harm/instructions user_message: type: string metadata: diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 3f374460b..25ee03ec1 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from enum import Enum, StrEnum +from enum import Enum from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel, Field @@ -15,27 +15,6 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod -# OpenAI Categories to return in the response -class OpenAICategories(StrEnum): - """ - Required set of categories in moderations api response - """ - - VIOLENCE = "violence" - VIOLENCE_GRAPHIC = "violence/graphic" - HARRASMENT = "harassment" - HARRASMENT_THREATENING = "harassment/threatening" - HATE = "hate" - HATE_THREATENING = "hate/threatening" - ILLICIT = "illicit" - ILLICIT_VIOLENT = "illicit/violent" - SEXUAL = "sexual" - SEXUAL_MINORS = "sexual/minors" - SELF_HARM = "self-harm" - SELF_HARM_INTENT = "self-harm/intent" - SELF_HARM_INSTRUCTIONS = "self-harm/instructions" - - @json_schema_type class ModerationObjectResults(BaseModel): """A moderation object. @@ -43,20 +22,6 @@ class ModerationObjectResults(BaseModel): :param categories: A list of the categories, and whether they are flagged or not. :param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to. :param category_scores: A list of the categories along with their scores as predicted by model. - Required set of categories that need to be in response - - violence - - violence/graphic - - harassment - - harassment/threatening - - hate - - hate/threatening - - illicit - - illicit/violent - - sexual - - sexual/minors - - self-harm - - self-harm/intent - - self-harm/instructions """ flagged: bool diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py index 9bf2b1bac..c76673d2a 100644 --- a/llama_stack/core/routers/safety.py +++ b/llama_stack/core/routers/safety.py @@ -10,7 +10,7 @@ from llama_stack.apis.inference import ( Message, ) from llama_stack.apis.safety import RunShieldResponse, Safety -from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories +from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable @@ -82,20 +82,5 @@ class SafetyRouter(Safety): input=input, model=model, ) - self._validate_required_categories_exist(response) return response - - def _validate_required_categories_exist(self, response: ModerationObject) -> None: - """Validate the ProviderImpl response contains the required Open AI moderations categories.""" - required_categories = list(map(str, OpenAICategories)) - - categories = response.results[0].categories - category_applied_input_types = response.results[0].category_applied_input_types - category_scores = response.results[0].category_scores - - for i in [categories, category_applied_input_types, category_scores]: - if not set(required_categories).issubset(set(i.keys())): - raise ValueError( - f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}" - ) diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index f83c39a6a..bae744010 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -22,7 +22,7 @@ from llama_stack.apis.safety import ( SafetyViolation, ViolationLevel, ) -from llama_stack.apis.safety.safety import ModerationObject, 
ModerationObjectResults, OpenAICategories +from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults from llama_stack.apis.shields import Shield from llama_stack.core.datatypes import Api from llama_stack.models.llama.datatypes import Role @@ -72,30 +72,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = { } SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()} -OPENAI_TO_LLAMA_CATEGORIES_MAP = { - OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES], - OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES], - OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION], - OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION], - OpenAICategories.HATE: [CAT_HATE], - OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES], - OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES], - OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS], - OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT], - OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION], - OpenAICategories.SELF_HARM: [CAT_SELF_HARM], - OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM], - OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE], - # These are custom categories that are not in the OpenAI moderation categories - "custom/defamation": [CAT_DEFAMATION], - "custom/specialized_advice": [CAT_SPECIALIZED_ADVICE], - "custom/privacy_violation": [CAT_PRIVACY], - "custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY], - "custom/weapons": [CAT_INDISCRIMINATE_WEAPONS], - "custom/elections": [CAT_ELECTIONS], - "custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE], -} - DEFAULT_LG_V3_SAFETY_CATEGORIES = [ CAT_VIOLENT_CRIMES, @@ -424,9 +400,9 @@ class LlamaGuardShield: ModerationObject with appropriate configuration """ # Set default values for safe case - categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False) - category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0) - category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()} + categories = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), False) + category_scores = dict.fromkeys(SAFETY_CATEGORIES_TO_CODE_MAP.keys(), 1.0) + category_applied_input_types = {key: [] for key in SAFETY_CATEGORIES_TO_CODE_MAP.keys()} flagged = False user_message = None metadata = {} @@ -453,19 +429,15 @@ class LlamaGuardShield: ], ) - # Get OpenAI categories for the unsafe codes - openai_categories = [] - for code in unsafe_code_list: - llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code] - openai_categories.extend( - k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l - ) + llama_guard_category = [SAFETY_CODE_TO_CATEGORIES_MAP[code] for code in unsafe_code_list] # Update categories for unsafe content - categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP} - category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP} + categories = {k: k in llama_guard_category for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys()} + category_scores = { + k: 1.0 if k in llama_guard_category else 0.0 for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys() + } category_applied_input_types = { - k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP + k: ["text"] if k in llama_guard_category else [] for k in SAFETY_CATEGORIES_TO_CODE_MAP.keys() } flagged = True user_message = 
CANNED_RESPONSE_TEXT diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 801500dee..c760f0fd1 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -18,6 +18,7 @@ from llama_stack.apis.safety import ( ShieldStore, ViolationLevel, ) +from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield from llama_stack.core.utils.model_utils import model_local_dir from llama_stack.providers.datatypes import ShieldsProtocolPrivate @@ -64,8 +65,8 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): return await self.shield.run(messages) - async def run_moderation(self, input: str | list[str], model: str): - raise NotImplementedError("run_moderation not implemented for PromptGuard") + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + raise NotImplementedError("run_moderation is not implemented for Prompt Guard") class PromptGuardShield: From 2f51273215ded844b62c7532c1e5e53ca46de1d4 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 13 Aug 2025 09:51:35 -0700 Subject: [PATCH 19/45] fix: huge speed boost (#3132) # What does this PR do? make llama stack fast again ## Test Plan --- llama_stack/core/server/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index d58037f82..e9d70fc8d 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -181,7 +181,6 @@ async def sse_generator(event_gen_coroutine): event_gen = await event_gen_coroutine async for item in event_gen: yield create_sse_event(item) - await asyncio.sleep(0.01) except asyncio.CancelledError: logger.info("Generator cancelled") if event_gen: From d6ae54723df7f1b95b896e81ab1f2b66e15e6642 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 13 Aug 2025 10:58:22 -0700 Subject: [PATCH 20/45] chore: setup for performance benchmarking (#3096) # What does this PR do? 1. Added a simple mock openai-compat server that serves chat/completion 2. Add a benchmark server in EKS that includes mock inference server 3. 
Add locust (https://locust.io/) file for load testing ## Test Plan bash apply.sh kubectl port-forward service/locust-web-ui 8089:8089 Go to localhost:8089 to start a load test. --- .../distributions/k8s-benchmark/apply.sh | 57 +++ .../k8s-benchmark/locust-k8s.yaml | 131 +++++++ .../distributions/k8s-benchmark/locustfile.py | 78 ++++ .../k8s-benchmark/openai-mock-deployment.yaml | 52 +++ .../k8s-benchmark/openai-mock-server.py | 190 ++++++++++ .../k8s-benchmark/stack-configmap.yaml | 143 +++++++ .../k8s-benchmark/stack-k8s.yaml.template | 87 +++++ .../k8s-benchmark/stack_run_config.yaml | 136 +++++++ .../distributions/k8s/stack-k8s.yaml.template | 6 +- pyproject.toml | 3 + uv.lock | 354 ++++++++++++++++++ 11 files changed, 1234 insertions(+), 3 deletions(-) create mode 100755 docs/source/distributions/k8s-benchmark/apply.sh create mode 100644 docs/source/distributions/k8s-benchmark/locust-k8s.yaml create mode 100644 docs/source/distributions/k8s-benchmark/locustfile.py create mode 100644 docs/source/distributions/k8s-benchmark/openai-mock-deployment.yaml create mode 100644 docs/source/distributions/k8s-benchmark/openai-mock-server.py create mode 100644 docs/source/distributions/k8s-benchmark/stack-configmap.yaml create mode 100644 docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template create mode 100644 docs/source/distributions/k8s-benchmark/stack_run_config.yaml diff --git a/docs/source/distributions/k8s-benchmark/apply.sh b/docs/source/distributions/k8s-benchmark/apply.sh new file mode 100755 index 000000000..119a1c849 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/apply.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Deploys the benchmark-specific components on top of the base k8s deployment (../k8s/apply.sh).
+ +export MOCK_INFERENCE_PORT=8080 +export STREAM_DELAY_SECONDS=0.005 + +export POSTGRES_USER=llamastack +export POSTGRES_DB=llamastack +export POSTGRES_PASSWORD=llamastack + +export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct +export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B + +export MOCK_INFERENCE_MODEL=mock-inference + +# Use llama-stack-benchmark-service as the benchmark server +export LOCUST_HOST=http://llama-stack-benchmark-service:8323 +export LOCUST_BASE_PATH=/v1/openai/v1 + +# Use vllm-service as the benchmark server +# export LOCUST_HOST=http://vllm-server:8000 +# export LOCUST_BASE_PATH=/v1 + + +export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL + +set -euo pipefail +set -x + +# Deploy benchmark-specific components +# Deploy OpenAI mock server +kubectl create configmap openai-mock --from-file=openai-mock-server.py \ + --dry-run=client -o yaml | kubectl apply --validate=false -f - + +envsubst < openai-mock-deployment.yaml | kubectl apply --validate=false -f - + +# Create configmap with our custom stack config +kubectl create configmap llama-stack-config --from-file=stack_run_config.yaml \ + --dry-run=client -o yaml > stack-configmap.yaml + +kubectl apply --validate=false -f stack-configmap.yaml + +# Deploy our custom llama stack server (overriding the base one) +envsubst < stack-k8s.yaml.template | kubectl apply --validate=false -f - + +# Deploy Locust load testing +kubectl create configmap locust-script --from-file=locustfile.py \ + --dry-run=client -o yaml | kubectl apply --validate=false -f - + +envsubst < locust-k8s.yaml | kubectl apply --validate=false -f - diff --git a/docs/source/distributions/k8s-benchmark/locust-k8s.yaml b/docs/source/distributions/k8s-benchmark/locust-k8s.yaml new file mode 100644 index 000000000..f20a01b2d --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/locust-k8s.yaml @@ -0,0 +1,131 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: locust-master + labels: + app: locust + role: master +spec: + replicas: 1 + selector: + matchLabels: + app: locust + role: master + template: + metadata: + labels: + app: locust + role: master + spec: + containers: + - name: locust-master + image: locustio/locust:2.31.8 + ports: + - containerPort: 8089 # Web UI + - containerPort: 5557 # Master communication + env: + - name: LOCUST_HOST + value: "${LOCUST_HOST}" + - name: LOCUST_LOCUSTFILE + value: "/locust/locustfile.py" + - name: LOCUST_WEB_HOST + value: "0.0.0.0" + - name: LOCUST_MASTER + value: "true" + - name: LOCUST_BASE_PATH + value: "${LOCUST_BASE_PATH}" + - name: INFERENCE_MODEL + value: "${BENCHMARK_INFERENCE_MODEL}" + volumeMounts: + - name: locust-script + mountPath: /locust + command: ["locust"] + args: + - "--master" + - "--web-host=0.0.0.0" + - "--web-port=8089" + - "--host=${LOCUST_HOST}" + - "--locustfile=/locust/locustfile.py" + volumes: + - name: locust-script + configMap: + name: locust-script +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: locust-worker + labels: + app: locust + role: worker +spec: + replicas: 2 # Start with 2 workers, can be scaled up + selector: + matchLabels: + app: locust + role: worker + template: + metadata: + labels: + app: locust + role: worker + spec: + containers: + - name: locust-worker + image: locustio/locust:2.31.8 + env: + - name: LOCUST_HOST + value: "${LOCUST_HOST}" + - name: LOCUST_LOCUSTFILE + value: "/locust/locustfile.py" + - name: LOCUST_MASTER_HOST + value: "locust-master-service" + - name: LOCUST_MASTER_PORT + value: "5557" + - name: INFERENCE_MODEL + value: 
"${BENCHMARK_INFERENCE_MODEL}" + - name: LOCUST_BASE_PATH + value: "${LOCUST_BASE_PATH}" + volumeMounts: + - name: locust-script + mountPath: /locust + command: ["locust"] + args: + - "--worker" + - "--master-host=locust-master-service" + - "--master-port=5557" + - "--locustfile=/locust/locustfile.py" + volumes: + - name: locust-script + configMap: + name: locust-script +--- +apiVersion: v1 +kind: Service +metadata: + name: locust-master-service +spec: + selector: + app: locust + role: master + ports: + - name: web-ui + port: 8089 + targetPort: 8089 + - name: master-comm + port: 5557 + targetPort: 5557 + type: ClusterIP +--- +apiVersion: v1 +kind: Service +metadata: + name: locust-web-ui +spec: + selector: + app: locust + role: master + ports: + - port: 8089 + targetPort: 8089 + type: ClusterIP # Keep internal, use port-forward to access diff --git a/docs/source/distributions/k8s-benchmark/locustfile.py b/docs/source/distributions/k8s-benchmark/locustfile.py new file mode 100644 index 000000000..8e511fa95 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/locustfile.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Locust load testing script for Llama Stack with Prism mock OpenAI provider. +""" + +import random +from locust import HttpUser, task, between +import os + +base_path = os.getenv("LOCUST_BASE_PATH", "/v1/openai/v1") + +MODEL_ID = os.getenv("INFERENCE_MODEL") + +class LlamaStackUser(HttpUser): + wait_time = between(0.0, 0.0001) + + def on_start(self): + """Setup authentication and test data.""" + # No auth required for benchmark server + self.headers = { + "Content-Type": "application/json" + } + + # Test messages of varying lengths + self.test_messages = [ + [{"role": "user", "content": "Hi"}], + [{"role": "user", "content": "What is the capital of France?"}], + [{"role": "user", "content": "Explain quantum physics in simple terms."}], + [{"role": "user", "content": "Write a short story about a robot learning to paint."}], + [ + {"role": "user", "content": "What is machine learning?"}, + {"role": "assistant", "content": "Machine learning is a subset of AI..."}, + {"role": "user", "content": "Can you give me a practical example?"} + ] + ] + + @task(weight=100) + def chat_completion_streaming(self): + """Test streaming chat completion (20% of requests).""" + messages = random.choice(self.test_messages) + payload = { + "model": MODEL_ID, + "messages": messages, + "stream": True, + "max_tokens": 100 + } + + with self.client.post( + f"{base_path}/chat/completions", + headers=self.headers, + json=payload, + stream=True, + catch_response=True + ) as response: + if response.status_code == 200: + chunks_received = 0 + try: + for line in response.iter_lines(): + if line: + line_str = line.decode('utf-8') + if line_str.startswith('data: '): + chunks_received += 1 + if line_str.strip() == 'data: [DONE]': + break + + if chunks_received > 0: + response.success() + else: + response.failure("No streaming chunks received") + except Exception as e: + response.failure(f"Streaming error: {e}") + else: + response.failure(f"HTTP {response.status_code}: {response.text}") diff --git a/docs/source/distributions/k8s-benchmark/openai-mock-deployment.yaml b/docs/source/distributions/k8s-benchmark/openai-mock-deployment.yaml new file mode 100644 index 000000000..c72921281 --- /dev/null +++ 
b/docs/source/distributions/k8s-benchmark/openai-mock-deployment.yaml @@ -0,0 +1,52 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: openai-mock + labels: + app: openai-mock +spec: + replicas: 1 + selector: + matchLabels: + app: openai-mock + template: + metadata: + labels: + app: openai-mock + spec: + containers: + - name: openai-mock + image: python:3.12-slim + ports: + - containerPort: ${MOCK_INFERENCE_PORT} + env: + - name: PORT + value: "${MOCK_INFERENCE_PORT}" + - name: MOCK_MODELS + value: "${MOCK_INFERENCE_MODEL}" + - name: STREAM_DELAY_SECONDS + value: "${STREAM_DELAY_SECONDS}" + command: ["sh", "-c"] + args: + - | + pip install flask && + python /app/openai-mock-server.py --port ${MOCK_INFERENCE_PORT} + volumeMounts: + - name: openai-mock-script + mountPath: /app + volumes: + - name: openai-mock-script + configMap: + name: openai-mock +--- +apiVersion: v1 +kind: Service +metadata: + name: openai-mock-service +spec: + selector: + app: openai-mock + ports: + - port: 8080 + targetPort: 8080 + type: ClusterIP diff --git a/docs/source/distributions/k8s-benchmark/openai-mock-server.py b/docs/source/distributions/k8s-benchmark/openai-mock-server.py new file mode 100644 index 000000000..46c923b60 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/openai-mock-server.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +OpenAI-compatible mock server that returns: +- Hardcoded /models response for consistent validation +- Valid OpenAI-formatted chat completion responses with dynamic content +""" + +from flask import Flask, request, jsonify, Response +import time +import random +import uuid +import json +import argparse +import os + +app = Flask(__name__) + +# Models from environment variables +def get_models(): + models_str = os.getenv("MOCK_MODELS", "mock-inference") + model_ids = [m.strip() for m in models_str.split(",") if m.strip()] + + return { + "object": "list", + "data": [ + { + "id": model_id, + "object": "model", + "created": 1234567890, + "owned_by": "vllm" + } + for model_id in model_ids + ] + } + +def generate_random_text(length=50): + """Generate random but coherent text for responses.""" + words = [ + "Hello", "there", "I'm", "an", "AI", "assistant", "ready", "to", "help", "you", + "with", "your", "questions", "and", "tasks", "today", "Let", "me","know", "what", + "you'd", "like", "to", "discuss", "or", "explore", "together", "I", "can", "assist", + "with", "various", "topics", "including", "coding", "writing", "analysis", "and", "more" + ] + return " ".join(random.choices(words, k=length)) + +@app.route('/models', methods=['GET']) +def list_models(): + models = get_models() + print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}") + return jsonify(models) + +@app.route('/chat/completions', methods=['POST']) +def chat_completions(): + """Return OpenAI-formatted chat completion responses.""" + data = request.get_json() + default_model = get_models()['data'][0]['id'] + model = data.get('model', default_model) + messages = data.get('messages', []) + stream = data.get('stream', False) + + print(f"[MOCK] Chat completion request - model: {model}, stream: {stream}") + + if stream: + return handle_streaming_completion(model, messages) + else: + return handle_non_streaming_completion(model, messages) + +def handle_non_streaming_completion(model, 
messages): + response_text = generate_random_text(random.randint(20, 80)) + + # Calculate realistic token counts + prompt_tokens = sum(len(str(msg.get('content', '')).split()) for msg in messages) + completion_tokens = len(response_text.split()) + + response = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": response_text + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens + } + } + + return jsonify(response) + +def handle_streaming_completion(model, messages): + def generate_stream(): + # Generate response text + full_response = generate_random_text(random.randint(30, 100)) + words = full_response.split() + + # Send initial chunk + initial_chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""} + } + ] + } + yield f"data: {json.dumps(initial_chunk)}\n\n" + + # Send word by word + for i, word in enumerate(words): + chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": f"{word} " if i < len(words) - 1 else word} + } + ] + } + yield f"data: {json.dumps(chunk)}\n\n" + # Configurable delay to simulate realistic streaming + stream_delay = float(os.getenv("STREAM_DELAY_SECONDS", "0.005")) + time.sleep(stream_delay) + + # Send final chunk + final_chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": ""}, + "finish_reason": "stop" + } + ] + } + yield f"data: {json.dumps(final_chunk)}\n\n" + yield "data: [DONE]\n\n" + + return Response( + generate_stream(), + mimetype='text/event-stream', + headers={ + 'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', + 'Access-Control-Allow-Origin': '*', + } + ) + +@app.route('/health', methods=['GET']) +def health(): + return jsonify({"status": "healthy", "type": "openai-mock"}) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='OpenAI-compatible mock server') + parser.add_argument('--port', type=int, default=8081, + help='Port to run the server on (default: 8081)') + args = parser.parse_args() + + port = args.port + + models = get_models() + print("Starting OpenAI-compatible mock server...") + print(f"- /models endpoint with: {[m['id'] for m in models['data']]}") + print("- OpenAI-formatted chat/completion responses with dynamic content") + print("- Streaming support with valid SSE format") + print(f"- Listening on: http://0.0.0.0:{port}") + app.run(host='0.0.0.0', port=port, debug=False) diff --git a/docs/source/distributions/k8s-benchmark/stack-configmap.yaml b/docs/source/distributions/k8s-benchmark/stack-configmap.yaml new file mode 100644 index 000000000..653e66756 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/stack-configmap.yaml @@ -0,0 +1,143 @@ +apiVersion: v1 +data: + stack_run_config.yaml: | + version: '2' + image_name: kubernetes-benchmark-demo + apis: + - agents + - inference + - safety + - telemetry + - tool_runtime + - vector_io + providers: + inference: + - provider_id: 
vllm-inference + provider_type: remote::vllm + config: + url: ${env.VLLM_URL:=http://localhost:8000/v1} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: vllm-safety + provider_type: remote::vllm + config: + url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: mock-vllm-inference + provider_type: remote::vllm + config: + url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT} + max_tokens: 4096 + api_token: fake + tls_verify: false + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: ${env.ENABLE_CHROMADB:+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:=} + kvstore: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + responses_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} + metadata_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + table_name: llamastack_kvstore + inference_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + models: + - metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: sentence-transformers + model_type: embedding + - model_id: ${env.INFERENCE_MODEL} + provider_id: vllm-inference + model_type: llm + - model_id: ${env.SAFETY_MODEL} + provider_id: vllm-safety + model_type: llm + - model_id: ${env.MOCK_INFERENCE_MODEL} + provider_id: mock-vllm-inference + model_type: llm + shields: + - shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} + vector_dbs: [] + datasets: [] + scoring_fns: [] + benchmarks: [] + tool_groups: + - 
toolgroup_id: builtin::websearch + provider_id: tavily-search + - toolgroup_id: builtin::rag + provider_id: rag-runtime + server: + port: 8323 +kind: ConfigMap +metadata: + creationTimestamp: null + name: llama-stack-config diff --git a/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template b/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template new file mode 100644 index 000000000..bc14d5124 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template @@ -0,0 +1,87 @@ +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: llama-benchmark-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 1Gi +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: llama-stack-benchmark-server +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/name: llama-stack-benchmark + app.kubernetes.io/component: server + template: + metadata: + labels: + app.kubernetes.io/name: llama-stack-benchmark + app.kubernetes.io/component: server + spec: + containers: + - name: llama-stack-benchmark + image: llamastack/distribution-starter:latest + imagePullPolicy: Always # since we have specified latest instead of a version + env: + - name: ENABLE_CHROMADB + value: "true" + - name: CHROMADB_URL + value: http://chromadb.default.svc.cluster.local:6000 + - name: POSTGRES_HOST + value: postgres-server.default.svc.cluster.local + - name: POSTGRES_PORT + value: "5432" + - name: INFERENCE_MODEL + value: "${INFERENCE_MODEL}" + - name: SAFETY_MODEL + value: "${SAFETY_MODEL}" + - name: TAVILY_SEARCH_API_KEY + value: "${TAVILY_SEARCH_API_KEY}" + - name: MOCK_INFERENCE_PORT + value: "${MOCK_INFERENCE_PORT}" + - name: VLLM_URL + value: http://vllm-server.default.svc.cluster.local:8000/v1 + - name: VLLM_MAX_TOKENS + value: "3072" + - name: VLLM_SAFETY_URL + value: http://vllm-server-safety.default.svc.cluster.local:8001/v1 + - name: VLLM_TLS_VERIFY + value: "false" + - name: MOCK_INFERENCE_MODEL + value: "${MOCK_INFERENCE_MODEL}" + command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"] + ports: + - containerPort: 8323 + volumeMounts: + - name: llama-storage + mountPath: /root/.llama + - name: llama-config + mountPath: /etc/config + volumes: + - name: llama-storage + persistentVolumeClaim: + claimName: llama-benchmark-pvc + - name: llama-config + configMap: + name: llama-stack-config +--- +apiVersion: v1 +kind: Service +metadata: + name: llama-stack-benchmark-service +spec: + selector: + app.kubernetes.io/name: llama-stack-benchmark + app.kubernetes.io/component: server + ports: + - name: http + port: 8323 + targetPort: 8323 + type: ClusterIP diff --git a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml new file mode 100644 index 000000000..ad56be047 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml @@ -0,0 +1,136 @@ +version: '2' +image_name: kubernetes-benchmark-demo +apis: +- agents +- inference +- safety +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: vllm-inference + provider_type: remote::vllm + config: + url: ${env.VLLM_URL:=http://localhost:8000/v1} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: vllm-safety + provider_type: remote::vllm + config: + url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1} + max_tokens: 
${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: mock-vllm-inference + provider_type: remote::vllm + config: + url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT} + max_tokens: 4096 + api_token: fake + tls_verify: false + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: ${env.ENABLE_CHROMADB:+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:=} + kvstore: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + responses_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:+} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} +metadata_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} + table_name: llamastack_kvstore +inference_store: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} +models: +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: sentence-transformers + model_type: embedding +- model_id: ${env.INFERENCE_MODEL} + provider_id: vllm-inference + model_type: llm +- model_id: ${env.SAFETY_MODEL} + provider_id: vllm-safety + model_type: llm +- model_id: ${env.MOCK_INFERENCE_MODEL} + provider_id: mock-vllm-inference + model_type: llm +shields: +- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +server: + port: 8323 diff --git a/docs/source/distributions/k8s/stack-k8s.yaml.template b/docs/source/distributions/k8s/stack-k8s.yaml.template index ad5d2c716..dfc049f4f 100644 --- a/docs/source/distributions/k8s/stack-k8s.yaml.template +++ 
b/docs/source/distributions/k8s/stack-k8s.yaml.template @@ -40,19 +40,19 @@ spec: value: "3072" - name: VLLM_SAFETY_URL value: http://vllm-server-safety.default.svc.cluster.local:8001/v1 + - name: VLLM_TLS_VERIFY + value: "false" - name: POSTGRES_HOST value: postgres-server.default.svc.cluster.local - name: POSTGRES_PORT value: "5432" - - name: VLLM_TLS_VERIFY - value: "false" - name: INFERENCE_MODEL value: "${INFERENCE_MODEL}" - name: SAFETY_MODEL value: "${SAFETY_MODEL}" - name: TAVILY_SEARCH_API_KEY value: "${TAVILY_SEARCH_API_KEY}" - command: ["python", "-m", "llama_stack.core.server.server", "--config", "/etc/config/stack_run_config.yaml", "--port", "8321"] + command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8321"] ports: - containerPort: 8321 volumeMounts: diff --git a/pyproject.toml b/pyproject.toml index 1b0850631..db0ad1f00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,6 +140,9 @@ docs = [ "requests", ] codegen = ["rich", "pydantic", "jinja2>=3.1.6"] +benchmark = [ + "locust>=2.37.14", +] [project.urls] Homepage = "https://github.com/meta-llama/llama-stack" diff --git a/uv.lock b/uv.lock index 9f4ba4adb..4c56816ef 100644 --- a/uv.lock +++ b/uv.lock @@ -290,6 +290,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/cf/45fb5261ece3e6b9817d3d82b2f343a505fd58674a92577923bc500bd1aa/bcrypt-4.3.0-cp39-abi3-win_amd64.whl", hash = "sha256:e53e074b120f2877a35cc6c736b8eb161377caae8925c17688bd46ba56daaa5b", size = 152799, upload-time = "2025-02-28T01:23:53.139Z" }, ] +[[package]] +name = "bidict" +version = "0.23.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/6e/026678aa5a830e07cd9498a05d3e7e650a4f56a42f267a53d22bcda1bdc9/bidict-0.23.1.tar.gz", hash = "sha256:03069d763bc387bbd20e7d49914e75fc4132a41937fa3405417e1a5a2d006d71", size = 29093, upload-time = "2024-02-18T19:09:05.748Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/37/e8730c3587a65eb5645d4aba2d27aae48e8003614d6aaf15dda67f702f1f/bidict-0.23.1-py3-none-any.whl", hash = "sha256:5dae8d4d79b552a71cbabc7deb25dfe8ce710b17ff41711e13010ead2abfc3e5", size = 32764, upload-time = "2024-02-18T19:09:04.156Z" }, +] + [[package]] name = "black" version = "25.1.0" @@ -347,6 +356,44 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/f6/776782c20b71b4da290ed0b25ccec0cbfca924d20f2ec26078876bce6d29/braintrust_core-0.0.59-py3-none-any.whl", hash = "sha256:b9be128e1c1b4c376f082e81d314c1938aa9b8c0398ab11df4ad29fad8e655c1", size = 4441, upload-time = "2025-05-12T22:05:12.088Z" }, ] +[[package]] +name = "brotli" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2f/c2/f9e977608bdf958650638c3f1e28f85a1b075f075ebbe77db8555463787b/Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724", size = 7372270, upload-time = "2023-09-07T14:05:41.643Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/d0/5373ae13b93fe00095a58efcbce837fd470ca39f703a235d2a999baadfbc/Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28", size = 815693, upload-time = "2024-10-18T12:32:23.824Z" }, + { url = "https://files.pythonhosted.org/packages/8e/48/f6e1cdf86751300c288c1459724bfa6917a80e30dbfc326f92cea5d3683a/Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f", size = 422489, upload-time = "2024-10-18T12:32:25.641Z" }, + { url = "https://files.pythonhosted.org/packages/06/88/564958cedce636d0f1bed313381dfc4b4e3d3f6015a63dae6146e1b8c65c/Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409", size = 873081, upload-time = "2023-09-07T14:03:57.967Z" }, + { url = "https://files.pythonhosted.org/packages/58/79/b7026a8bb65da9a6bb7d14329fd2bd48d2b7f86d7329d5cc8ddc6a90526f/Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2", size = 446244, upload-time = "2023-09-07T14:03:59.319Z" }, + { url = "https://files.pythonhosted.org/packages/e5/18/c18c32ecea41b6c0004e15606e274006366fe19436b6adccc1ae7b2e50c2/Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451", size = 2906505, upload-time = "2023-09-07T14:04:01.327Z" }, + { url = "https://files.pythonhosted.org/packages/08/c8/69ec0496b1ada7569b62d85893d928e865df29b90736558d6c98c2031208/Brotli-1.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7f4bf76817c14aa98cc6697ac02f3972cb8c3da93e9ef16b9c66573a68014f91", size = 2944152, upload-time = "2023-09-07T14:04:03.033Z" }, + { url = "https://files.pythonhosted.org/packages/ab/fb/0517cea182219d6768113a38167ef6d4eb157a033178cc938033a552ed6d/Brotli-1.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d0c5516f0aed654134a2fc936325cc2e642f8a0e096d075209672eb321cff408", size = 2919252, upload-time = "2023-09-07T14:04:04.675Z" }, + { url = "https://files.pythonhosted.org/packages/c7/53/73a3431662e33ae61a5c80b1b9d2d18f58dfa910ae8dd696e57d39f1a2f5/Brotli-1.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6c3020404e0b5eefd7c9485ccf8393cfb75ec38ce75586e046573c9dc29967a0", size = 2845955, upload-time = "2023-09-07T14:04:06.585Z" }, + { url = "https://files.pythonhosted.org/packages/55/ac/bd280708d9c5ebdbf9de01459e625a3e3803cce0784f47d633562cf40e83/Brotli-1.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:4ed11165dd45ce798d99a136808a794a748d5dc38511303239d4e2363c0695dc", size = 2914304, upload-time = "2023-09-07T14:04:08.668Z" }, + { url = "https://files.pythonhosted.org/packages/76/58/5c391b41ecfc4527d2cc3350719b02e87cb424ef8ba2023fb662f9bf743c/Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180", size = 2814452, upload-time = "2023-09-07T14:04:10.736Z" }, + { url = "https://files.pythonhosted.org/packages/c7/4e/91b8256dfe99c407f174924b65a01f5305e303f486cc7a2e8a5d43c8bec3/Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248", size = 2938751, upload-time = "2023-09-07T14:04:12.875Z" }, + { url = "https://files.pythonhosted.org/packages/5a/a6/e2a39a5d3b412938362bbbeba5af904092bf3f95b867b4a3eb856104074e/Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966", size = 2933757, upload-time = "2023-09-07T14:04:14.551Z" }, + { url = "https://files.pythonhosted.org/packages/13/f0/358354786280a509482e0e77c1a5459e439766597d280f28cb097642fc26/Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9", size = 2936146, upload-time = "2024-10-18T12:32:27.257Z" }, + { url = "https://files.pythonhosted.org/packages/80/f7/daf538c1060d3a88266b80ecc1d1c98b79553b3f117a485653f17070ea2a/Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb", size = 2848055, upload-time = "2024-10-18T12:32:29.376Z" }, + { url = "https://files.pythonhosted.org/packages/ad/cf/0eaa0585c4077d3c2d1edf322d8e97aabf317941d3a72d7b3ad8bce004b0/Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111", size = 3035102, upload-time = "2024-10-18T12:32:31.371Z" }, + { url = "https://files.pythonhosted.org/packages/d8/63/1c1585b2aa554fe6dbce30f0c18bdbc877fa9a1bf5ff17677d9cca0ac122/Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839", size = 2930029, upload-time = "2024-10-18T12:32:33.293Z" }, + { url = "https://files.pythonhosted.org/packages/5f/3b/4e3fd1893eb3bbfef8e5a80d4508bec17a57bb92d586c85c12d28666bb13/Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0", size = 333276, upload-time = "2023-09-07T14:04:16.49Z" }, + { url = "https://files.pythonhosted.org/packages/3d/d5/942051b45a9e883b5b6e98c041698b1eb2012d25e5948c58d6bf85b1bb43/Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951", size = 357255, upload-time = "2023-09-07T14:04:17.83Z" }, + { url = "https://files.pythonhosted.org/packages/0a/9f/fb37bb8ffc52a8da37b1c03c459a8cd55df7a57bdccd8831d500e994a0ca/Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5", size = 815681, upload-time = "2024-10-18T12:32:34.942Z" }, + { url = "https://files.pythonhosted.org/packages/06/b3/dbd332a988586fefb0aa49c779f59f47cae76855c2d00f450364bb574cac/Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8", size = 422475, upload-time = "2024-10-18T12:32:36.485Z" }, + { url = "https://files.pythonhosted.org/packages/bb/80/6aaddc2f63dbcf2d93c2d204e49c11a9ec93a8c7c63261e2b4bd35198283/Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f", size = 2906173, upload-time = "2024-10-18T12:32:37.978Z" }, + { url = "https://files.pythonhosted.org/packages/ea/1d/e6ca79c96ff5b641df6097d299347507d39a9604bde8915e76bf026d6c77/Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648", size = 2943803, upload-time = "2024-10-18T12:32:39.606Z" }, + { url = "https://files.pythonhosted.org/packages/ac/a3/d98d2472e0130b7dd3acdbb7f390d478123dbf62b7d32bda5c830a96116d/Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0", size = 2918946, upload-time = "2024-10-18T12:32:41.679Z" }, + { url = "https://files.pythonhosted.org/packages/c4/a5/c69e6d272aee3e1423ed005d8915a7eaa0384c7de503da987f2d224d0721/Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089", size = 2845707, upload-time = "2024-10-18T12:32:43.478Z" }, + { url = "https://files.pythonhosted.org/packages/58/9f/4149d38b52725afa39067350696c09526de0125ebfbaab5acc5af28b42ea/Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368", size = 2936231, upload-time = "2024-10-18T12:32:45.224Z" }, + { url = "https://files.pythonhosted.org/packages/5a/5a/145de884285611838a16bebfdb060c231c52b8f84dfbe52b852a15780386/Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c", size = 2848157, upload-time = "2024-10-18T12:32:46.894Z" }, + { url = "https://files.pythonhosted.org/packages/50/ae/408b6bfb8525dadebd3b3dd5b19d631da4f7d46420321db44cd99dcf2f2c/Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284", size = 3035122, upload-time = "2024-10-18T12:32:48.844Z" }, + { url = "https://files.pythonhosted.org/packages/af/85/a94e5cfaa0ca449d8f91c3d6f78313ebf919a0dbd55a100c711c6e9655bc/Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7", size = 2930206, upload-time = "2024-10-18T12:32:51.198Z" }, + { url = "https://files.pythonhosted.org/packages/c2/f0/a61d9262cd01351df22e57ad7c34f66794709acab13f34be2675f45bf89d/Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0", size = 333804, upload-time = "2024-10-18T12:32:52.661Z" }, + { url = "https://files.pythonhosted.org/packages/7e/c1/ec214e9c94000d1c1974ec67ced1c970c148aa6b8d8373066123fc3dbf06/Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b", size = 358517, upload-time = "2024-10-18T12:32:54.066Z" }, +] + [[package]] name = "build" version = "1.2.2.post1" @@ -558,6 +605,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl", hash = "sha256:c615d91d75f7f04f095b30d1c1711babd43bdc6419c1be9886a85f2f4e489417", size = 7294, upload-time = "2025-07-25T14:02:02.896Z" }, ] +[[package]] +name = "configargparse" +version = "1.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/85/4d/6c9ef746dfcc2a32e26f3860bb4a011c008c392b83eabdfb598d1a8bbe5d/configargparse-1.7.1.tar.gz", hash = "sha256:79c2ddae836a1e5914b71d58e4b9adbd9f7779d4e6351a637b7d2d9b6c46d3d9", size = 43958, upload-time = "2025-05-23T14:26:17.369Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/28/d28211d29bcc3620b1fece85a65ce5bb22f18670a03cd28ea4b75ede270c/configargparse-1.7.1-py3-none-any.whl", hash = "sha256:8b586a31f9d873abd1ca527ffbe58863c99f36d896e2829779803125e83be4b6", size = 25607, upload-time = "2025-05-23T14:26:15.923Z" }, +] + [[package]] name = "coverage" version = "7.10.1" @@ -872,6 +928,49 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/6b/b6/82c7e601d6d3c3278c40b7bd35e17e82aa227f050aa9f66cb7b7fce29471/fire-0.7.0.tar.gz", hash = "sha256:961550f07936eaf65ad1dc8360f2b2bf8408fad46abbfa4d2a3794f8d2a95cdf", size = 87189, upload-time = "2024-10-01T14:29:31.585Z" } +[[package]] +name = "flask" +version = "3.1.1" +source = { registry = "https://pypi.org/simple" } 
+dependencies = [ + { name = "blinker" }, + { name = "click" }, + { name = "itsdangerous" }, + { name = "jinja2" }, + { name = "markupsafe" }, + { name = "werkzeug" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/de/e47735752347f4128bcf354e0da07ef311a78244eba9e3dc1d4a5ab21a98/flask-3.1.1.tar.gz", hash = "sha256:284c7b8f2f58cb737f0cf1c30fd7eaf0ccfcde196099d24ecede3fc2005aa59e", size = 753440, upload-time = "2025-05-13T15:01:17.447Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/68/9d4508e893976286d2ead7f8f571314af6c2037af34853a30fd769c02e9d/flask-3.1.1-py3-none-any.whl", hash = "sha256:07aae2bb5eaf77993ef57e357491839f5fd9f4dc281593a81a9e4d79a24f295c", size = 103305, upload-time = "2025-05-13T15:01:15.591Z" }, +] + +[[package]] +name = "flask-cors" +version = "6.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "flask" }, + { name = "werkzeug" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/76/37/bcfa6c7d5eec777c4c7cf45ce6b27631cebe5230caf88d85eadd63edd37a/flask_cors-6.0.1.tar.gz", hash = "sha256:d81bcb31f07b0985be7f48406247e9243aced229b7747219160a0559edd678db", size = 13463, upload-time = "2025-06-11T01:32:08.518Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/f8/01bf35a3afd734345528f98d0353f2a978a476528ad4d7e78b70c4d149dd/flask_cors-6.0.1-py3-none-any.whl", hash = "sha256:c7b2cbfb1a31aa0d2e5341eea03a6805349f7a61647daee1a15c46bbe981494c", size = 13244, upload-time = "2025-06-11T01:32:07.352Z" }, +] + +[[package]] +name = "flask-login" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "flask" }, + { name = "werkzeug" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c3/6e/2f4e13e373bb49e68c02c51ceadd22d172715a06716f9299d9df01b6ddb2/Flask-Login-0.6.3.tar.gz", hash = "sha256:5e23d14a607ef12806c699590b89d0f0e0d67baeec599d75947bf9c147330333", size = 48834, upload-time = "2023-10-30T14:53:21.151Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/f5/67e9cc5c2036f58115f9fe0f00d203cf6780c3ff8ae0e705e7a9d9e8ff9e/Flask_Login-0.6.3-py3-none-any.whl", hash = "sha256:849b25b82a436bf830a054e74214074af59097171562ab10bfa999e6b78aae5d", size = 17303, upload-time = "2023-10-30T14:53:19.636Z" }, +] + [[package]] name = "flatbuffers" version = "25.2.10" @@ -955,6 +1054,77 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "gevent" +version = "25.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "platform_python_implementation == 'CPython' and sys_platform == 'win32'" }, + { name = "greenlet", marker = "platform_python_implementation == 'CPython'" }, + { name = "zope-event" }, + { name = "zope-interface" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f1/58/267e8160aea00ab00acd2de97197eecfe307064a376fb5c892870a8a6159/gevent-25.5.1.tar.gz", hash = "sha256:582c948fa9a23188b890d0bc130734a506d039a2e5ad87dae276a456cc683e61", size = 6388207, upload-time = "2025-05-12T12:57:59.833Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/58/c5/cf71423666a0b83db3d7e3f85788bc47d573fca5fe62b798fe2c4273de7c/gevent-25.5.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:d87c0a1bd809d8f70f96b9b229779ec6647339830b8888a192beed33ac8d129f", size = 2909333, upload-time = "2025-05-12T11:11:34.883Z" }, + { url = 
"https://files.pythonhosted.org/packages/26/7e/d2f174ee8bec6eb85d961ca203bc599d059c857b8412e367b8fa206603a5/gevent-25.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b87a4b66edb3808d4d07bbdb0deed5a710cf3d3c531e082759afd283758bb649", size = 1788420, upload-time = "2025-05-12T11:52:30.306Z" }, + { url = "https://files.pythonhosted.org/packages/fe/f3/3aba8c147b9108e62ba348c726fe38ae69735a233db425565227336e8ce6/gevent-25.5.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f076779050029a82feb0cb1462021d3404d22f80fa76a181b1a7889cd4d6b519", size = 1868854, upload-time = "2025-05-12T11:54:21.564Z" }, + { url = "https://files.pythonhosted.org/packages/c6/b1/11a5453f8fcebe90a456471fad48bd154c6a62fcb96e3475a5e408d05fc8/gevent-25.5.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bb673eb291c19370f69295f7a881a536451408481e2e3deec3f41dedb7c281ec", size = 1833946, upload-time = "2025-05-12T12:00:05.514Z" }, + { url = "https://files.pythonhosted.org/packages/70/1c/37d4a62303f86e6af67660a8df38c1171b7290df61b358e618c6fea79567/gevent-25.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1325ed44225c8309c0dd188bdbbbee79e1df8c11ceccac226b861c7d52e4837", size = 2070583, upload-time = "2025-05-12T11:33:02.803Z" }, + { url = "https://files.pythonhosted.org/packages/4b/8f/3b14929ff28263aba1d268ea97bcf104be1a86ba6f6bb4633838e7a1905e/gevent-25.5.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:fcd5bcad3102bde686d0adcc341fade6245186050ce14386d547ccab4bd54310", size = 1808341, upload-time = "2025-05-12T11:59:59.154Z" }, + { url = "https://files.pythonhosted.org/packages/2f/fc/674ec819fb8a96e482e4d21f8baa43d34602dba09dfce7bbdc8700899d1b/gevent-25.5.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1a93062609e8fa67ec97cd5fb9206886774b2a09b24887f40148c9c37e6fb71c", size = 2137974, upload-time = "2025-05-12T11:40:54.78Z" }, + { url = "https://files.pythonhosted.org/packages/05/9a/048b7f5e28c54e4595ad4a8ad3c338fa89560e558db2bbe8273f44f030de/gevent-25.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:2534c23dc32bed62b659ed4fd9e198906179e68b26c9276a897e04163bdde806", size = 1638344, upload-time = "2025-05-12T12:08:31.776Z" }, + { url = "https://files.pythonhosted.org/packages/10/25/2162b38d7b48e08865db6772d632bd1648136ce2bb50e340565e45607cad/gevent-25.5.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a022a9de9275ce0b390b7315595454258c525dc8287a03f1a6cacc5878ab7cbc", size = 2928044, upload-time = "2025-05-12T11:11:36.33Z" }, + { url = "https://files.pythonhosted.org/packages/1b/e0/dbd597a964ed00176da122ea759bf2a6c1504f1e9f08e185379f92dc355f/gevent-25.5.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3fae8533f9d0ef3348a1f503edcfb531ef7a0236b57da1e24339aceb0ce52922", size = 1788751, upload-time = "2025-05-12T11:52:32.643Z" }, + { url = "https://files.pythonhosted.org/packages/f1/74/960cc4cf4c9c90eafbe0efc238cdf588862e8e278d0b8c0d15a0da4ed480/gevent-25.5.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c7b32d9c3b5294b39ea9060e20c582e49e1ec81edbfeae6cf05f8ad0829cb13d", size = 1869766, upload-time = "2025-05-12T11:54:23.903Z" }, + { url = "https://files.pythonhosted.org/packages/56/78/fa84b1c7db79b156929685db09a7c18c3127361dca18a09e998e98118506/gevent-25.5.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b95815fe44f318ebbfd733b6428b4cb18cc5e68f1c40e8501dd69cc1f42a83d", size = 1835358, upload-time = 
"2025-05-12T12:00:06.794Z" }, + { url = "https://files.pythonhosted.org/packages/00/5c/bfefe3822bbca5b83bfad256c82251b3f5be13d52d14e17a786847b9b625/gevent-25.5.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d316529b70d325b183b2f3f5cde958911ff7be12eb2b532b5c301f915dbbf1e", size = 2073071, upload-time = "2025-05-12T11:33:04.2Z" }, + { url = "https://files.pythonhosted.org/packages/20/e4/08a77a3839a37db96393dea952e992d5846a881b887986dde62ead6b48a1/gevent-25.5.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f6ba33c13db91ffdbb489a4f3d177a261ea1843923e1d68a5636c53fe98fa5ce", size = 1809805, upload-time = "2025-05-12T12:00:00.537Z" }, + { url = "https://files.pythonhosted.org/packages/2b/ac/28848348f790c1283df74b0fc0a554271d0606676470f848eccf84eae42a/gevent-25.5.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:37ee34b77c7553777c0b8379915f75934c3f9c8cd32f7cd098ea43c9323c2276", size = 2138305, upload-time = "2025-05-12T11:40:56.566Z" }, + { url = "https://files.pythonhosted.org/packages/52/9e/0e9e40facd2d714bfb00f71fc6dacaacc82c24c1c2e097bf6461e00dec9f/gevent-25.5.1-cp313-cp313-win_amd64.whl", hash = "sha256:9fa6aa0da224ed807d3b76cdb4ee8b54d4d4d5e018aed2478098e685baae7896", size = 1637444, upload-time = "2025-05-12T12:17:45.995Z" }, + { url = "https://files.pythonhosted.org/packages/60/16/b71171e97ec7b4ded8669542f4369d88d5a289e2704efbbde51e858e062a/gevent-25.5.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:0bacf89a65489d26c7087669af89938d5bfd9f7afb12a07b57855b9fad6ccbd0", size = 2937113, upload-time = "2025-05-12T11:12:03.191Z" }, +] + +[[package]] +name = "geventhttpclient" +version = "2.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "brotli" }, + { name = "certifi" }, + { name = "gevent" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/89/19/1ca8de73dcc0596d3df01be299e940d7fc3bccbeb6f62bb8dd2d427a3a50/geventhttpclient-2.3.4.tar.gz", hash = "sha256:1749f75810435a001fc6d4d7526c92cf02b39b30ab6217a886102f941c874222", size = 83545, upload-time = "2025-06-11T13:18:14.144Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/72/dcbc6dbf838549b7b0c2c18c1365d2580eb7456939e4b608c3ab213fce78/geventhttpclient-2.3.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9ac30c38d86d888b42bb2ab2738ab9881199609e9fa9a153eb0c66fc9188c6cb", size = 71984, upload-time = "2025-06-11T13:17:09.126Z" }, + { url = "https://files.pythonhosted.org/packages/4c/f9/74aa8c556364ad39b238919c954a0da01a6154ad5e85a1d1ab5f9f5ac186/geventhttpclient-2.3.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b802000a4fad80fa57e895009671d6e8af56777e3adf0d8aee0807e96188fd9", size = 52631, upload-time = "2025-06-11T13:17:10.061Z" }, + { url = "https://files.pythonhosted.org/packages/11/1a/bc4b70cba8b46be8b2c6ca5b8067c4f086f8c90915eb68086ab40ff6243d/geventhttpclient-2.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:461e4d9f4caee481788ec95ac64e0a4a087c1964ddbfae9b6f2dc51715ba706c", size = 51991, upload-time = "2025-06-11T13:17:11.049Z" }, + { url = "https://files.pythonhosted.org/packages/b0/f5/8d0f1e998f6d933c251b51ef92d11f7eb5211e3cd579018973a2b455f7c5/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41f2dcc0805551ea9d49f9392c3b9296505a89b9387417b148655d0d8251b36e", size = 119012, upload-time = "2025-06-11T13:17:11.956Z" }, + { url = 
"https://files.pythonhosted.org/packages/ea/0e/59e4ab506b3c19fc72e88ca344d150a9028a00c400b1099637100bec26fc/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:62f3a29bf242ecca6360d497304900683fd8f42cbf1de8d0546c871819251dad", size = 124565, upload-time = "2025-06-11T13:17:12.896Z" }, + { url = "https://files.pythonhosted.org/packages/39/5d/dcbd34dfcda0c016b4970bd583cb260cc5ebfc35b33d0ec9ccdb2293587a/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8714a3f2c093aeda3ffdb14c03571d349cb3ed1b8b461d9f321890659f4a5dbf", size = 115573, upload-time = "2025-06-11T13:17:13.937Z" }, + { url = "https://files.pythonhosted.org/packages/03/51/89af99e4805e9ce7f95562dfbd23c0b0391830831e43d58f940ec74489ac/geventhttpclient-2.3.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b11f38b74bab75282db66226197024a731250dcbe25542fd4e85ac5313547332", size = 114260, upload-time = "2025-06-11T13:17:14.913Z" }, + { url = "https://files.pythonhosted.org/packages/b3/ec/3a3000bda432953abcc6f51d008166fa7abc1eeddd1f0246933d83854f73/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:fccc2023a89dfbce2e1b1409b967011e45d41808df81b7fa0259397db79ba647", size = 111592, upload-time = "2025-06-11T13:17:15.879Z" }, + { url = "https://files.pythonhosted.org/packages/d8/a3/88fd71fe6bbe1315a2d161cbe2cc7810c357d99bced113bea1668ede8bcf/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9d54b8e9a44890159ae36ba4ae44efd8bb79ff519055137a340d357538a68aa3", size = 113216, upload-time = "2025-06-11T13:17:16.883Z" }, + { url = "https://files.pythonhosted.org/packages/52/eb/20435585a6911b26e65f901a827ef13551c053133926f8c28a7cca0fb08e/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:407cb68a3c3a2c4f5d503930298f2b26ae68137d520e8846d8e230a9981d9334", size = 118450, upload-time = "2025-06-11T13:17:17.968Z" }, + { url = "https://files.pythonhosted.org/packages/2f/79/82782283d613570373990b676a0966c1062a38ca8f41a0f20843c5808e01/geventhttpclient-2.3.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:54fbbcca2dcf06f12a337dd8f98417a09a49aa9d9706aa530fc93acb59b7d83c", size = 112226, upload-time = "2025-06-11T13:17:18.942Z" }, + { url = "https://files.pythonhosted.org/packages/9c/c4/417d12fc2a31ad93172b03309c7f8c3a8bbd0cf25b95eb7835de26b24453/geventhttpclient-2.3.4-cp312-cp312-win32.whl", hash = "sha256:83143b41bde2eb010c7056f142cb764cfbf77f16bf78bda2323a160767455cf5", size = 48365, upload-time = "2025-06-11T13:17:20.096Z" }, + { url = "https://files.pythonhosted.org/packages/cf/f4/7e5ee2f460bbbd09cb5d90ff63a1cf80d60f1c60c29dac20326324242377/geventhttpclient-2.3.4-cp312-cp312-win_amd64.whl", hash = "sha256:46eda9a9137b0ca7886369b40995d2a43a5dff033d0a839a54241015d1845d41", size = 48961, upload-time = "2025-06-11T13:17:21.111Z" }, + { url = "https://files.pythonhosted.org/packages/ff/ad/132fddde6e2dca46d6a86316962437acd2bfaeb264db4e0fae83c529eb04/geventhttpclient-2.3.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:be64c5583884c407fc748dedbcb083475d5b138afb23c6bc0836cbad228402cc", size = 71967, upload-time = "2025-06-11T13:17:22.121Z" }, + { url = "https://files.pythonhosted.org/packages/f4/34/5e77d9a31d93409a8519cf573843288565272ae5a016be9c9293f56c50a1/geventhttpclient-2.3.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = 
"sha256:15b2567137734183efda18e4d6245b18772e648b6a25adea0eba8b3a8b0d17e8", size = 52632, upload-time = "2025-06-11T13:17:23.016Z" }, + { url = "https://files.pythonhosted.org/packages/47/d2/cf0dbc333304700e68cee9347f654b56e8b0f93a341b8b0d027ee96800d6/geventhttpclient-2.3.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a4bca1151b8cd207eef6d5cb3c720c562b2aa7293cf113a68874e235cfa19c31", size = 51980, upload-time = "2025-06-11T13:17:23.933Z" }, + { url = "https://files.pythonhosted.org/packages/ec/5b/c0c30ccd9d06c603add3f2d6abd68bd98430ee9730dc5478815759cf07f7/geventhttpclient-2.3.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b50d9daded5d36193d67e2fc30e59752262fcbbdc86e8222c7df6b93af0346a", size = 118987, upload-time = "2025-06-11T13:17:24.97Z" }, + { url = "https://files.pythonhosted.org/packages/4f/56/095a46af86476372064128162eccbd2ba4a7721503759890d32ea701d5fd/geventhttpclient-2.3.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fe705e7656bc6982a463a4ed7f9b1db8c78c08323f1d45d0d1d77063efa0ce96", size = 124519, upload-time = "2025-06-11T13:17:25.933Z" }, + { url = "https://files.pythonhosted.org/packages/ae/12/7c9ba94b58f7954a83d33183152ce6bf5bda10c08ebe47d79a314cd33e29/geventhttpclient-2.3.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:69668589359db4cbb9efa327dda5735d1e74145e6f0a9ffa50236d15cf904053", size = 115574, upload-time = "2025-06-11T13:17:27.331Z" }, + { url = "https://files.pythonhosted.org/packages/73/77/c4e7c5bce0199428fdb811d6adf6e347180d89eaa1b9b723f711f6bbc830/geventhttpclient-2.3.4-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9ba526e07ccaf4f1c2cd3395dda221139f01468b6eee1190d4a616f187a0378", size = 114222, upload-time = "2025-06-11T13:17:28.289Z" }, + { url = "https://files.pythonhosted.org/packages/a3/79/58802d300950dbd7d4e31eb24afd7c270fc7900ff3923fd266cc915bb086/geventhttpclient-2.3.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:525bd192705b5cb41a7cc3fe41fca194bfd6b5b59997ab9fe68fe0a82dab6140", size = 111682, upload-time = "2025-06-11T13:17:29.291Z" }, + { url = "https://files.pythonhosted.org/packages/d3/9c/ae04e4033459b8142788dad80d8d0b42d460bc6db9150e0815c2d0a02cb4/geventhttpclient-2.3.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:42b6f6afb0d3aab6a013c9cdb97e19bf4fe08695975670d0a018113d24cb344c", size = 113252, upload-time = "2025-06-11T13:17:30.357Z" }, + { url = "https://files.pythonhosted.org/packages/d3/67/5ae5d5878b06397a7b54334d1d31bb78cefc950ae890c2b8f5c917eb271e/geventhttpclient-2.3.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:227579b703085c4e5c6d5217ad6565b19ac8d1164404133e5874efaae1905114", size = 118426, upload-time = "2025-06-11T13:17:31.363Z" }, + { url = "https://files.pythonhosted.org/packages/ca/36/9065bb51f261950c42eddf8718e01a9ff344d8082e31317a8b6677be9bd6/geventhttpclient-2.3.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8d1d0db89c1c8f3282eac9a22fda2b4082e1ed62a2107f70e3f1de1872c7919f", size = 112245, upload-time = "2025-06-11T13:17:32.331Z" }, + { url = "https://files.pythonhosted.org/packages/21/7e/08a615bec095c288f997951e42e48b262d43c6081bef33cfbfad96ab9658/geventhttpclient-2.3.4-cp313-cp313-win32.whl", hash = "sha256:4e492b9ab880f98f8a9cc143b96ea72e860946eae8ad5fb2837cede2a8f45154", size = 48360, upload-time = "2025-06-11T13:17:33.349Z" }, + { url = 
"https://files.pythonhosted.org/packages/ec/19/ef3cb21e7e95b14cfcd21e3ba7fe3d696e171682dfa43ab8c0a727cac601/geventhttpclient-2.3.4-cp313-cp313-win_amd64.whl", hash = "sha256:72575c5b502bf26ececccb905e4e028bb922f542946be701923e726acf305eb6", size = 48956, upload-time = "2025-06-11T13:17:34.956Z" }, +] + [[package]] name = "gitdb" version = "4.0.12" @@ -1342,6 +1512,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, ] +[[package]] +name = "itsdangerous" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" }, +] + [[package]] name = "jedi" version = "0.19.2" @@ -1580,6 +1759,9 @@ ui = [ ] [package.dev-dependencies] +benchmark = [ + { name = "locust" }, +] codegen = [ { name = "jinja2" }, { name = "pydantic" }, @@ -1695,6 +1877,7 @@ requires-dist = [ provides-extras = ["ui"] [package.metadata.requires-dev] +benchmark = [{ name = "locust", specifier = ">=2.37.14" }] codegen = [ { name = "jinja2", specifier = ">=3.1.6" }, { name = "pydantic" }, @@ -1800,6 +1983,47 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/fc/5eccc86b83c5ced3a3bca071d250a86ccafa4ff17546cf781deb7758ab74/llama_stack_client-0.2.17-py3-none-any.whl", hash = "sha256:336c32f8688700ff64717b8109f405dc87a990fbe310c2027ac9ed6d39d67d16", size = 350329, upload-time = "2025-08-05T01:42:54.381Z" }, ] +[[package]] +name = "locust" +version = "2.38.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "configargparse" }, + { name = "flask" }, + { name = "flask-cors" }, + { name = "flask-login" }, + { name = "gevent" }, + { name = "geventhttpclient" }, + { name = "locust-cloud" }, + { name = "msgpack" }, + { name = "psutil" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "pyzmq" }, + { name = "requests" }, + { name = "setuptools" }, + { name = "werkzeug" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fb/93/ecd79dde28e24bdc99488d4e2c0ad4117252257d5cbdd61e3b14d1f03786/locust-2.38.0.tar.gz", hash = "sha256:5bd6c29d8423733cb5d9a265548c9fef7b731f2254aa91885d6c98d0d39f90f0", size = 1406518, upload-time = "2025-08-07T10:18:52.584Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/be/57ca67b95c45e69c173e86fe5c934d789effc2ec203d3e3ec2a0b32aa707/locust-2.38.0-py3-none-any.whl", hash = "sha256:b92c937e8659e9ffd6d6d1cab2f63f70aa98c87975911938d1f473534f46fd78", size = 1424083, upload-time = "2025-08-07T10:18:50.499Z" }, +] + +[[package]] +name = "locust-cloud" +version = "1.26.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "configargparse" }, + { name = "gevent" }, + { name = "platformdirs" }, + { name = "python-engineio" }, + { name = 
"python-socketio", extra = ["client"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/84/ad/10b299b134068a4250a9156e6832a717406abe1dfea2482a07ae7bdca8f3/locust_cloud-1.26.3.tar.gz", hash = "sha256:587acfd4d2dee715fb5f0c3c2d922770babf0b7cff7b2927afbb693a9cd193cc", size = 456042, upload-time = "2025-07-15T19:51:53.791Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/6a/276fc50a9d170e7cbb6715735480cb037abb526639bca85491576e6eee4a/locust_cloud-1.26.3-py3-none-any.whl", hash = "sha256:8cb4b8bb9adcd5b99327bc8ed1d98cf67a29d9d29512651e6e94869de6f1faa8", size = 410023, upload-time = "2025-07-15T19:51:52.056Z" }, +] + [[package]] name = "lxml" version = "6.0.0" @@ -2017,6 +2241,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, ] +[[package]] +name = "msgpack" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/45/b1/ea4f68038a18c77c9467400d166d74c4ffa536f34761f7983a104357e614/msgpack-1.1.1.tar.gz", hash = "sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd", size = 173555, upload-time = "2025-06-13T06:52:51.324Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/26/389b9c593eda2b8551b2e7126ad3a06af6f9b44274eb3a4f054d48ff7e47/msgpack-1.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ae497b11f4c21558d95de9f64fff7053544f4d1a17731c866143ed6bb4591238", size = 82359, upload-time = "2025-06-13T06:52:03.909Z" }, + { url = "https://files.pythonhosted.org/packages/ab/65/7d1de38c8a22cf8b1551469159d4b6cf49be2126adc2482de50976084d78/msgpack-1.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:33be9ab121df9b6b461ff91baac6f2731f83d9b27ed948c5b9d1978ae28bf157", size = 79172, upload-time = "2025-06-13T06:52:05.246Z" }, + { url = "https://files.pythonhosted.org/packages/0f/bd/cacf208b64d9577a62c74b677e1ada005caa9b69a05a599889d6fc2ab20a/msgpack-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f64ae8fe7ffba251fecb8408540c34ee9df1c26674c50c4544d72dbf792e5ce", size = 425013, upload-time = "2025-06-13T06:52:06.341Z" }, + { url = "https://files.pythonhosted.org/packages/4d/ec/fd869e2567cc9c01278a736cfd1697941ba0d4b81a43e0aa2e8d71dab208/msgpack-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a494554874691720ba5891c9b0b39474ba43ffb1aaf32a5dac874effb1619e1a", size = 426905, upload-time = "2025-06-13T06:52:07.501Z" }, + { url = "https://files.pythonhosted.org/packages/55/2a/35860f33229075bce803a5593d046d8b489d7ba2fc85701e714fc1aaf898/msgpack-1.1.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb643284ab0ed26f6957d969fe0dd8bb17beb567beb8998140b5e38a90974f6c", size = 407336, upload-time = "2025-06-13T06:52:09.047Z" }, + { url = "https://files.pythonhosted.org/packages/8c/16/69ed8f3ada150bf92745fb4921bd621fd2cdf5a42e25eb50bcc57a5328f0/msgpack-1.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d275a9e3c81b1093c060c3837e580c37f47c51eca031f7b5fb76f7b8470f5f9b", size = 409485, upload-time = "2025-06-13T06:52:10.382Z" }, + { url = 
"https://files.pythonhosted.org/packages/c6/b6/0c398039e4c6d0b2e37c61d7e0e9d13439f91f780686deb8ee64ecf1ae71/msgpack-1.1.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4fd6b577e4541676e0cc9ddc1709d25014d3ad9a66caa19962c4f5de30fc09ef", size = 412182, upload-time = "2025-06-13T06:52:11.644Z" }, + { url = "https://files.pythonhosted.org/packages/b8/d0/0cf4a6ecb9bc960d624c93effaeaae75cbf00b3bc4a54f35c8507273cda1/msgpack-1.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bb29aaa613c0a1c40d1af111abf025f1732cab333f96f285d6a93b934738a68a", size = 419883, upload-time = "2025-06-13T06:52:12.806Z" }, + { url = "https://files.pythonhosted.org/packages/62/83/9697c211720fa71a2dfb632cad6196a8af3abea56eece220fde4674dc44b/msgpack-1.1.1-cp312-cp312-win32.whl", hash = "sha256:870b9a626280c86cff9c576ec0d9cbcc54a1e5ebda9cd26dab12baf41fee218c", size = 65406, upload-time = "2025-06-13T06:52:14.271Z" }, + { url = "https://files.pythonhosted.org/packages/c0/23/0abb886e80eab08f5e8c485d6f13924028602829f63b8f5fa25a06636628/msgpack-1.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:5692095123007180dca3e788bb4c399cc26626da51629a31d40207cb262e67f4", size = 72558, upload-time = "2025-06-13T06:52:15.252Z" }, + { url = "https://files.pythonhosted.org/packages/a1/38/561f01cf3577430b59b340b51329803d3a5bf6a45864a55f4ef308ac11e3/msgpack-1.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3765afa6bd4832fc11c3749be4ba4b69a0e8d7b728f78e68120a157a4c5d41f0", size = 81677, upload-time = "2025-06-13T06:52:16.64Z" }, + { url = "https://files.pythonhosted.org/packages/09/48/54a89579ea36b6ae0ee001cba8c61f776451fad3c9306cd80f5b5c55be87/msgpack-1.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8ddb2bcfd1a8b9e431c8d6f4f7db0773084e107730ecf3472f1dfe9ad583f3d9", size = 78603, upload-time = "2025-06-13T06:52:17.843Z" }, + { url = "https://files.pythonhosted.org/packages/a0/60/daba2699b308e95ae792cdc2ef092a38eb5ee422f9d2fbd4101526d8a210/msgpack-1.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:196a736f0526a03653d829d7d4c5500a97eea3648aebfd4b6743875f28aa2af8", size = 420504, upload-time = "2025-06-13T06:52:18.982Z" }, + { url = "https://files.pythonhosted.org/packages/20/22/2ebae7ae43cd8f2debc35c631172ddf14e2a87ffcc04cf43ff9df9fff0d3/msgpack-1.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d592d06e3cc2f537ceeeb23d38799c6ad83255289bb84c2e5792e5a8dea268a", size = 423749, upload-time = "2025-06-13T06:52:20.211Z" }, + { url = "https://files.pythonhosted.org/packages/40/1b/54c08dd5452427e1179a40b4b607e37e2664bca1c790c60c442c8e972e47/msgpack-1.1.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4df2311b0ce24f06ba253fda361f938dfecd7b961576f9be3f3fbd60e87130ac", size = 404458, upload-time = "2025-06-13T06:52:21.429Z" }, + { url = "https://files.pythonhosted.org/packages/2e/60/6bb17e9ffb080616a51f09928fdd5cac1353c9becc6c4a8abd4e57269a16/msgpack-1.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e4141c5a32b5e37905b5940aacbc59739f036930367d7acce7a64e4dec1f5e0b", size = 405976, upload-time = "2025-06-13T06:52:22.995Z" }, + { url = "https://files.pythonhosted.org/packages/ee/97/88983e266572e8707c1f4b99c8fd04f9eb97b43f2db40e3172d87d8642db/msgpack-1.1.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b1ce7f41670c5a69e1389420436f41385b1aa2504c3b0c30620764b15dded2e7", size = 408607, upload-time = "2025-06-13T06:52:24.152Z" }, + { url = 
"https://files.pythonhosted.org/packages/bc/66/36c78af2efaffcc15a5a61ae0df53a1d025f2680122e2a9eb8442fed3ae4/msgpack-1.1.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4147151acabb9caed4e474c3344181e91ff7a388b888f1e19ea04f7e73dc7ad5", size = 424172, upload-time = "2025-06-13T06:52:25.704Z" }, + { url = "https://files.pythonhosted.org/packages/8c/87/a75eb622b555708fe0427fab96056d39d4c9892b0c784b3a721088c7ee37/msgpack-1.1.1-cp313-cp313-win32.whl", hash = "sha256:500e85823a27d6d9bba1d057c871b4210c1dd6fb01fbb764e37e4e8847376323", size = 65347, upload-time = "2025-06-13T06:52:26.846Z" }, + { url = "https://files.pythonhosted.org/packages/ca/91/7dc28d5e2a11a5ad804cf2b7f7a5fcb1eb5a4966d66a5d2b41aee6376543/msgpack-1.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:6d489fba546295983abd142812bda76b57e33d0b9f5d5b71c09a583285506f69", size = 72341, upload-time = "2025-06-13T06:52:27.835Z" }, +] + [[package]] name = "multidict" version = "6.6.3" @@ -3306,6 +3558,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5f/ed/539768cf28c661b5b068d66d96a2f155c4971a5d55684a514c1a0e0dec2f/python_dotenv-1.1.1-py3-none-any.whl", hash = "sha256:31f23644fe2602f88ff55e1f5c79ba497e01224ee7737937930c448e4d0e24dc", size = 20556, upload-time = "2025-06-24T04:21:06.073Z" }, ] +[[package]] +name = "python-engineio" +version = "4.12.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "simple-websocket" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ba/0b/67295279b66835f9fa7a491650efcd78b20321c127036eef62c11a31e028/python_engineio-4.12.2.tar.gz", hash = "sha256:e7e712ffe1be1f6a05ee5f951e72d434854a32fcfc7f6e4d9d3cae24ec70defa", size = 91677, upload-time = "2025-06-04T19:22:18.789Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" }, +] + [[package]] name = "python-jose" version = "3.5.0" @@ -3334,6 +3598,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size = 24546, upload-time = "2024-12-16T19:45:44.423Z" }, ] +[[package]] +name = "python-socketio" +version = "5.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bidict" }, + { name = "python-engineio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/21/1a/396d50ccf06ee539fa758ce5623b59a9cb27637fc4b2dc07ed08bf495e77/python_socketio-5.13.0.tar.gz", hash = "sha256:ac4e19a0302ae812e23b712ec8b6427ca0521f7c582d6abb096e36e24a263029", size = 121125, upload-time = "2025-04-12T15:46:59.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/32/b4fb8585d1be0f68bde7e110dffbcf354915f77ad8c778563f0ad9655c02/python_socketio-5.13.0-py3-none-any.whl", hash = "sha256:51f68d6499f2df8524668c24bcec13ba1414117cfb3a90115c559b601ab10caf", size = 77800, upload-time = "2025-04-12T15:46:58.412Z" }, +] + +[package.optional-dependencies] +client = [ + { name = "requests" }, + { name = "websocket-client" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -3726,6 +4009,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, ] +[[package]] +name = "simple-websocket" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wsproto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b0/d4/bfa032f961103eba93de583b161f0e6a5b63cebb8f2c7d0c6e6efe1e3d2e/simple_websocket-1.1.0.tar.gz", hash = "sha256:7939234e7aa067c534abdab3a9ed933ec9ce4691b0713c78acb195560aa52ae4", size = 17300, upload-time = "2024-10-10T22:39:31.412Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/59/0782e51887ac6b07ffd1570e0364cf901ebc36345fea669969d2084baebb/simple_websocket-1.1.0-py3-none-any.whl", hash = "sha256:4af6069630a38ed6c561010f0e11a5bc0d4ca569b36306eb257cd9a192497c8c", size = 13842, upload-time = "2024-10-10T22:39:29.645Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -4796,6 +5091,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] +[[package]] +name = "werkzeug" +version = "3.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9f/69/83029f1f6300c5fb2471d621ab06f6ec6b3324685a2ce0f9777fd4a8b71e/werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746", size = 806925, upload-time = "2024-11-08T15:52:18.093Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/24/ab44c871b0f07f491e5d2ad12c9bd7358e527510618cb1b803a88e986db1/werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e", size = 224498, upload-time = "2024-11-08T15:52:16.132Z" }, +] + +[[package]] +name = "wsproto" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/4a/44d3c295350d776427904d73c189e10aeae66d7f555bb2feee16d1e4ba5a/wsproto-1.2.0.tar.gz", hash = "sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065", size = 53425, upload-time = "2022-08-23T19:58:21.447Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/58/e860788190eba3bcce367f74d29c4675466ce8dddfba85f7827588416f01/wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736", size = 24226, upload-time = "2022-08-23T19:58:19.96Z" }, +] + [[package]] name = "xxhash" version = "3.5.0" @@ -4907,3 +5226,38 @@ sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50e wheels = [ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, ] + +[[package]] +name = "zope-event" +version = "5.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/5a/9f/c443569a68d3844c044d9fa9711e08adb33649b527b4d432433f4c2a6a02/zope_event-5.1.1.tar.gz", hash = "sha256:c1ac931abf57efba71a2a313c5f4d57768a19b15c37e3f02f50eb1536be12d4e", size = 18811, upload-time = "2025-07-22T07:04:00.924Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/04/fd55695f6448abd22295fc68b2d3a135389558f0f49a24b0dffe019d0ecb/zope_event-5.1.1-py3-none-any.whl", hash = "sha256:8d5ea7b992c42ce73a6fa9c2ba99a004c52cd9f05d87f3220768ef0329b92df7", size = 7014, upload-time = "2025-07-22T07:03:59.9Z" }, +] + +[[package]] +name = "zope-interface" +version = "7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/93/9210e7606be57a2dfc6277ac97dcc864fd8d39f142ca194fdc186d596fda/zope.interface-7.2.tar.gz", hash = "sha256:8b49f1a3d1ee4cdaf5b32d2e738362c7f5e40ac8b46dd7d1a65e82a4872728fe", size = 252960, upload-time = "2024-11-28T08:45:39.224Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/0b/c7516bc3bad144c2496f355e35bd699443b82e9437aa02d9867653203b4a/zope.interface-7.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:086ee2f51eaef1e4a52bd7d3111a0404081dadae87f84c0ad4ce2649d4f708b7", size = 208959, upload-time = "2024-11-28T08:47:47.788Z" }, + { url = "https://files.pythonhosted.org/packages/a2/e9/1463036df1f78ff8c45a02642a7bf6931ae4a38a4acd6a8e07c128e387a7/zope.interface-7.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:21328fcc9d5b80768bf051faa35ab98fb979080c18e6f84ab3f27ce703bce465", size = 209357, upload-time = "2024-11-28T08:47:50.897Z" }, + { url = "https://files.pythonhosted.org/packages/07/a8/106ca4c2add440728e382f1b16c7d886563602487bdd90004788d45eb310/zope.interface-7.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6dd02ec01f4468da0f234da9d9c8545c5412fef80bc590cc51d8dd084138a89", size = 264235, upload-time = "2024-11-28T09:18:15.56Z" }, + { url = "https://files.pythonhosted.org/packages/fc/ca/57286866285f4b8a4634c12ca1957c24bdac06eae28fd4a3a578e30cf906/zope.interface-7.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e7da17f53e25d1a3bde5da4601e026adc9e8071f9f6f936d0fe3fe84ace6d54", size = 259253, upload-time = "2024-11-28T08:48:29.025Z" }, + { url = "https://files.pythonhosted.org/packages/96/08/2103587ebc989b455cf05e858e7fbdfeedfc3373358320e9c513428290b1/zope.interface-7.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cab15ff4832580aa440dc9790b8a6128abd0b88b7ee4dd56abacbc52f212209d", size = 264702, upload-time = "2024-11-28T08:48:37.363Z" }, + { url = "https://files.pythonhosted.org/packages/5f/c7/3c67562e03b3752ba4ab6b23355f15a58ac2d023a6ef763caaca430f91f2/zope.interface-7.2-cp312-cp312-win_amd64.whl", hash = "sha256:29caad142a2355ce7cfea48725aa8bcf0067e2b5cc63fcf5cd9f97ad12d6afb5", size = 212466, upload-time = "2024-11-28T08:49:14.397Z" }, + { url = "https://files.pythonhosted.org/packages/c6/3b/e309d731712c1a1866d61b5356a069dd44e5b01e394b6cb49848fa2efbff/zope.interface-7.2-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:3e0350b51e88658d5ad126c6a57502b19d5f559f6cb0a628e3dc90442b53dd98", size = 208961, upload-time = "2024-11-28T08:48:29.865Z" }, + { url = "https://files.pythonhosted.org/packages/49/65/78e7cebca6be07c8fc4032bfbb123e500d60efdf7b86727bb8a071992108/zope.interface-7.2-cp313-cp313-macosx_11_0_arm64.whl", 
hash = "sha256:15398c000c094b8855d7d74f4fdc9e73aa02d4d0d5c775acdef98cdb1119768d", size = 209356, upload-time = "2024-11-28T08:48:33.297Z" }, + { url = "https://files.pythonhosted.org/packages/11/b1/627384b745310d082d29e3695db5f5a9188186676912c14b61a78bbc6afe/zope.interface-7.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:802176a9f99bd8cc276dcd3b8512808716492f6f557c11196d42e26c01a69a4c", size = 264196, upload-time = "2024-11-28T09:18:17.584Z" }, + { url = "https://files.pythonhosted.org/packages/b8/f6/54548df6dc73e30ac6c8a7ff1da73ac9007ba38f866397091d5a82237bd3/zope.interface-7.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb23f58a446a7f09db85eda09521a498e109f137b85fb278edb2e34841055398", size = 259237, upload-time = "2024-11-28T08:48:31.71Z" }, + { url = "https://files.pythonhosted.org/packages/b6/66/ac05b741c2129fdf668b85631d2268421c5cd1a9ff99be1674371139d665/zope.interface-7.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a71a5b541078d0ebe373a81a3b7e71432c61d12e660f1d67896ca62d9628045b", size = 264696, upload-time = "2024-11-28T08:48:41.161Z" }, + { url = "https://files.pythonhosted.org/packages/0a/2f/1bccc6f4cc882662162a1158cda1a7f616add2ffe322b28c99cb031b4ffc/zope.interface-7.2-cp313-cp313-win_amd64.whl", hash = "sha256:4893395d5dd2ba655c38ceb13014fd65667740f09fa5bb01caa1e6284e48c0cd", size = 212472, upload-time = "2024-11-28T08:49:56.587Z" }, +] From 5b312a80b99beebb87e5414fb28940806aa5d347 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 13 Aug 2025 11:23:27 -0700 Subject: [PATCH 21/45] feat(responses): improve streaming for function calls (#3124) Emit streaming events for function calls ## Test Plan Improved the test case --- .../agents/meta_reference/openai_responses.py | 152 +++++++++++++++--- .../non_ci/responses/test_responses.py | 113 ++++++++++++- .../meta_reference/test_openai_responses.py | 18 ++- 3 files changed, 250 insertions(+), 33 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 347954908..104f15010 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -33,6 +33,10 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObjectStream, OpenAIResponseObjectStreamResponseCompleted, OpenAIResponseObjectStreamResponseCreated, + OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta, + OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone, + OpenAIResponseObjectStreamResponseOutputItemAdded, + OpenAIResponseObjectStreamResponseOutputItemDone, OpenAIResponseObjectStreamResponseOutputTextDelta, OpenAIResponseOutput, OpenAIResponseOutputMessageContent, @@ -73,7 +77,9 @@ from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition -from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool +from llama_stack.providers.utils.inference.openai_compat import ( + convert_tooldef_to_openai_tool, +) from llama_stack.providers.utils.responses.responses_store import ResponsesStore logger = get_logger(name=__name__, 
category="openai_responses") @@ -82,7 +88,7 @@ OPENAI_RESPONSES_PREFIX = "openai_responses:" async def _convert_response_content_to_chat_content( - content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent], + content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]), ) -> str | list[OpenAIChatCompletionContentPartParam]: """ Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts. @@ -150,7 +156,9 @@ async def _convert_response_input_to_chat_messages( return messages -async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage: +async def _convert_chat_choice_to_response_message( + choice: OpenAIChoice, +) -> OpenAIResponseMessage: """ Convert an OpenAI Chat Completion choice into an OpenAI Response output message. """ @@ -172,7 +180,9 @@ async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> Open ) -async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam: +async def _convert_response_text_to_chat_response_format( + text: OpenAIResponseText, +) -> OpenAIResponseFormatParam: """ Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format. """ @@ -228,7 +238,9 @@ class OpenAIResponsesImpl: self.vector_io_api = vector_io_api async def _prepend_previous_response( - self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None + self, + input: str | list[OpenAIResponseInput], + previous_response_id: str | None = None, ): if previous_response_id: previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) @@ -446,6 +458,8 @@ class OpenAIResponsesImpl: # Create a placeholder message item for delta events message_item_id = f"msg_{uuid.uuid4()}" + # Track tool call items for streaming events + tool_call_item_ids: dict[int, str] = {} async for chunk in completion_result: chat_response_id = chunk.id @@ -472,18 +486,62 @@ class OpenAIResponsesImpl: if chunk_choice.delta.tool_calls: for tool_call in chunk_choice.delta.tool_calls: response_tool_call = chat_response_tool_calls.get(tool_call.index, None) - if response_tool_call: - # Don't attempt to concatenate arguments if we don't have any new argumentsAdd commentMore actions - if tool_call.function.arguments: - # Guard against an initial None argument before we concatenate - response_tool_call.function.arguments = ( - response_tool_call.function.arguments or "" - ) + tool_call.function.arguments - else: + # Create new tool call entry if this is the first chunk for this index + is_new_tool_call = response_tool_call is None + if is_new_tool_call: tool_call_dict: dict[str, Any] = tool_call.model_dump() tool_call_dict.pop("type", None) response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) - chat_response_tool_calls[tool_call.index] = response_tool_call + chat_response_tool_calls[tool_call.index] = response_tool_call + + # Create item ID for this tool call for streaming events + tool_call_item_id = f"fc_{uuid.uuid4()}" + tool_call_item_ids[tool_call.index] = tool_call_item_id + + # Emit output_item.added event for the new function call + sequence_number += 1 + function_call_item = OpenAIResponseOutputMessageFunctionToolCall( + arguments="", # Will be filled incrementally via delta events + call_id=tool_call.id or "", + name=tool_call.function.name if tool_call.function else "", + id=tool_call_item_id, + 
status="in_progress", + ) + yield OpenAIResponseObjectStreamResponseOutputItemAdded( + response_id=response_id, + item=function_call_item, + output_index=len(output_messages), + sequence_number=sequence_number, + ) + + # Stream function call arguments as they arrive + if tool_call.function and tool_call.function.arguments: + tool_call_item_id = tool_call_item_ids[tool_call.index] + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta( + delta=tool_call.function.arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=sequence_number, + ) + + # Accumulate arguments for final response (only for subsequent chunks) + if not is_new_tool_call: + response_tool_call.function.arguments = ( + response_tool_call.function.arguments or "" + ) + tool_call.function.arguments + + # Emit function_call_arguments.done events for completed tool calls + for tool_call_index in sorted(chat_response_tool_calls.keys()): + tool_call_item_id = tool_call_item_ids[tool_call_index] + final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or "" + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone( + arguments=final_arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=sequence_number, + ) # Convert collected chunks to complete response if chat_response_tool_calls: @@ -532,18 +590,56 @@ class OpenAIResponsesImpl: tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx) if tool_call_log: output_messages.append(tool_call_log) + + # Emit output_item.done event for completed non-function tool call + # Find the item_id for this tool call + matching_item_id = None + for index, item_id in tool_call_item_ids.items(): + response_tool_call = chat_response_tool_calls.get(index) + if response_tool_call and response_tool_call.id == tool_call.id: + matching_item_id = item_id + break + + if matching_item_id: + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=response_id, + item=tool_call_log, + output_index=len(output_messages) - 1, + sequence_number=sequence_number, + ) + if tool_response_message: next_turn_messages.append(tool_response_message) for tool_call in function_tool_calls: - output_messages.append( - OpenAIResponseOutputMessageFunctionToolCall( - arguments=tool_call.function.arguments or "", - call_id=tool_call.id, - name=tool_call.function.name or "", - id=f"fc_{uuid.uuid4()}", - status="completed", - ) + # Find the item_id for this tool call from our tracking dictionary + matching_item_id = None + for index, item_id in tool_call_item_ids.items(): + response_tool_call = chat_response_tool_calls.get(index) + if response_tool_call and response_tool_call.id == tool_call.id: + matching_item_id = item_id + break + + # Use existing item_id or create new one if not found + final_item_id = matching_item_id or f"fc_{uuid.uuid4()}" + + function_call_item = OpenAIResponseOutputMessageFunctionToolCall( + arguments=tool_call.function.arguments or "", + call_id=tool_call.id, + name=tool_call.function.name or "", + id=final_item_id, + status="completed", + ) + output_messages.append(function_call_item) + + # Emit output_item.done event for completed function call + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=response_id, + item=function_call_item, + output_index=len(output_messages) - 1, + sequence_number=sequence_number, ) if not 
function_tool_calls and not non_function_tool_calls: @@ -779,7 +875,8 @@ class OpenAIResponsesImpl: ) elif function.name == "knowledge_search": response_file_search_tool = next( - (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), None + (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), + None, ) if response_file_search_tool: # Use vector_stores.search API instead of knowledge_search tool @@ -798,7 +895,9 @@ class OpenAIResponsesImpl: error_exc = e if function.name in ctx.mcp_tool_to_server: - from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall + from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseOutputMessageMCPCall, + ) message = OpenAIResponseOutputMessageMCPCall( id=tool_call_id, @@ -850,7 +949,10 @@ class OpenAIResponsesImpl: if isinstance(result.content, str): content = result.content elif isinstance(result.content, list): - from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem + from llama_stack.apis.common.content_types import ( + ImageContentItem, + TextContentItem, + ) content = [] for item in result.content: diff --git a/tests/integration/non_ci/responses/test_responses.py b/tests/integration/non_ci/responses/test_responses.py index 39d00f328..6092346b0 100644 --- a/tests/integration/non_ci/responses/test_responses.py +++ b/tests/integration/non_ci/responses/test_responses.py @@ -384,12 +384,18 @@ def test_response_non_streaming_mcp_tool(request, compat_client, text_model_id, assert list_tools.type == "mcp_list_tools" assert list_tools.server_label == "localmcp" assert len(list_tools.tools) == 2 - assert {t.name for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"} + assert {t.name for t in list_tools.tools} == { + "get_boiling_point", + "greet_everyone", + } call = response.output[1] assert call.type == "mcp_call" assert call.name == "get_boiling_point" - assert json.loads(call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True} + assert json.loads(call.arguments) == { + "liquid_name": "myawesomeliquid", + "celsius": True, + } assert call.error is None assert "-100" in call.output @@ -581,6 +587,105 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_ f"Last chunk should be response.completed, got {chunks[-1].type}" ) + # Verify tool call streaming events are present + chunk_types = [chunk.type for chunk in chunks] + + # Should have function call arguments delta events for tool calls + delta_events = [chunk for chunk in chunks if chunk.type == "response.function_call_arguments.delta"] + done_events = [chunk for chunk in chunks if chunk.type == "response.function_call_arguments.done"] + + # Should have output item events for tool calls + item_added_events = [chunk for chunk in chunks if chunk.type == "response.output_item.added"] + item_done_events = [chunk for chunk in chunks if chunk.type == "response.output_item.done"] + + # Verify we have substantial streaming activity (not just batch events) + assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks" + + # Since this test involves MCP tool calls, we should see streaming events + assert len(delta_events) > 0, f"Expected function_call_arguments.delta events, got chunk types: {chunk_types}" + assert len(done_events) > 0, f"Expected function_call_arguments.done events, got chunk types: {chunk_types}" + + # Should have output item events for function calls + assert 
len(item_added_events) > 0, f"Expected response.output_item.added events, got chunk types: {chunk_types}" + assert len(item_done_events) > 0, f"Expected response.output_item.done events, got chunk types: {chunk_types}" + + # Verify delta events have proper structure + for delta_event in delta_events: + assert hasattr(delta_event, "delta"), "Delta event should have 'delta' field" + assert hasattr(delta_event, "item_id"), "Delta event should have 'item_id' field" + assert hasattr(delta_event, "sequence_number"), "Delta event should have 'sequence_number' field" + assert delta_event.delta, "Delta should not be empty" + + # Verify done events have proper structure + for done_event in done_events: + assert hasattr(done_event, "arguments"), "Done event should have 'arguments' field" + assert hasattr(done_event, "item_id"), "Done event should have 'item_id' field" + assert done_event.arguments, "Final arguments should not be empty" + + # Verify output item added events have proper structure + for added_event in item_added_events: + assert hasattr(added_event, "item"), "Added event should have 'item' field" + assert hasattr(added_event, "output_index"), "Added event should have 'output_index' field" + assert hasattr(added_event, "sequence_number"), "Added event should have 'sequence_number' field" + assert hasattr(added_event, "response_id"), "Added event should have 'response_id' field" + assert added_event.item.type in ["function_call", "mcp_call"], "Added item should be a tool call" + assert added_event.item.status == "in_progress", "Added item should be in progress" + assert added_event.response_id, "Response ID should not be empty" + assert isinstance(added_event.output_index, int), "Output index should be integer" + assert added_event.output_index >= 0, "Output index should be non-negative" + + # Verify output item done events have proper structure + for done_event in item_done_events: + assert hasattr(done_event, "item"), "Done event should have 'item' field" + assert hasattr(done_event, "output_index"), "Done event should have 'output_index' field" + assert hasattr(done_event, "sequence_number"), "Done event should have 'sequence_number' field" + assert hasattr(done_event, "response_id"), "Done event should have 'response_id' field" + assert done_event.item.type in ["function_call", "mcp_call"], "Done item should be a tool call" + # Note: MCP calls don't have a status field, only function calls do + if done_event.item.type == "function_call": + assert done_event.item.status == "completed", "Function call should be completed" + assert done_event.response_id, "Response ID should not be empty" + assert isinstance(done_event.output_index, int), "Output index should be integer" + assert done_event.output_index >= 0, "Output index should be non-negative" + + # Group function call argument events by item_id (these should have proper tracking) + function_call_events_by_item_id = {} + for chunk in chunks: + if hasattr(chunk, "item_id") and chunk.type in [ + "response.function_call_arguments.delta", + "response.function_call_arguments.done", + ]: + item_id = chunk.item_id + if item_id not in function_call_events_by_item_id: + function_call_events_by_item_id[item_id] = [] + function_call_events_by_item_id[item_id].append(chunk) + + for item_id, related_events in function_call_events_by_item_id.items(): + # Should have at least one delta and one done event for a complete function call + delta_events = [e for e in related_events if e.type == "response.function_call_arguments.delta"] + done_events = [e 
for e in related_events if e.type == "response.function_call_arguments.done"] + + assert len(delta_events) > 0, f"Item {item_id} should have at least one delta event" + assert len(done_events) == 1, f"Item {item_id} should have exactly one done event" + + # Verify all events have the same item_id + for event in related_events: + assert event.item_id == item_id, f"Event should have consistent item_id {item_id}, got {event.item_id}" + + # Basic pairing check: each output_item.added should be followed by some activity + # (but we can't enforce strict 1:1 pairing due to the complexity of multi-turn scenarios) + assert len(item_added_events) > 0, "Should have at least one output_item.added event" + + # Verify response_id consistency across all events + response_ids = set() + for chunk in chunks: + if hasattr(chunk, "response_id"): + response_ids.add(chunk.response_id) + elif hasattr(chunk, "response") and hasattr(chunk.response, "id"): + response_ids.add(chunk.response.id) + + assert len(response_ids) == 1, f"All events should reference the same response_id, found: {response_ids}" + # Get the final response from the last chunk final_chunk = chunks[-1] if hasattr(final_chunk, "response"): @@ -722,7 +827,9 @@ def vector_store_with_filtered_files(compat_client, text_model_id, tmp_path_fact # Attach file to vector store with attributes file_attach_response = compat_client.vector_stores.files.create( - vector_store_id=vector_store.id, file_id=file_response.id, attributes=file_data["attributes"] + vector_store_id=vector_store.id, + file_id=file_response.id, + attributes=file_data["attributes"], ) # Wait for attachment diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index 2ab5b557e..855a525e9 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -272,7 +272,9 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_ # Check that we got the content from our mocked tool execution result chunks = [chunk async for chunk in result] - assert len(chunks) == 2 # Should have response.created and response.completed + # Should have: response.created, output_item.added, function_call_arguments.delta, + # function_call_arguments.done, output_item.done, response.completed + assert len(chunks) == 6 # Verify inference API was called correctly (after iterating over result) first_call = mock_inference_api.openai_chat_completion.call_args_list[0] @@ -284,11 +286,17 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_ assert chunks[0].type == "response.created" assert len(chunks[0].response.output) == 0 + # Check streaming events + assert chunks[1].type == "response.output_item.added" + assert chunks[2].type == "response.function_call_arguments.delta" + assert chunks[3].type == "response.function_call_arguments.done" + assert chunks[4].type == "response.output_item.done" + # Check response.completed event (should have the tool call) - assert chunks[1].type == "response.completed" - assert len(chunks[1].response.output) == 1 - assert chunks[1].response.output[0].type == "function_call" - assert chunks[1].response.output[0].name == "get_weather" + assert chunks[5].type == "response.completed" + assert len(chunks[5].response.output) == 1 + assert chunks[5].response.output[0].type == "function_call" + assert chunks[5].response.output[0].name == "get_weather" async 
def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api): From 8638537d14f0dc4a0b3a1acdaa295894f205b83f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 13 Aug 2025 16:31:25 -0700 Subject: [PATCH 22/45] feat(responses): stream progress of tool calls (#3135) # What does this PR do? Enhances tool execution streaming by adding support for real-time progress events during tool calls. This implementation adds streaming events for MCP and web search tools, including in-progress, searching, completed, and failed states. The refactored `_execute_tool_call` method now returns an async iterator that yields streaming events throughout the tool execution lifecycle. ## Test Plan Updated the integration test `test_response_streaming_multi_turn_tool_execution` to verify the presence and structure of new streaming events, including: - Checking for MCP in-progress and completed events - Verifying that progress events contain required fields (item_id, output_index, sequence_number) - Ensuring completed events have the necessary sequence_number field --- .../agents/meta_reference/openai_responses.py | 137 +++++++++++++++--- .../non_ci/responses/test_responses.py | 22 +++ 2 files changed, 141 insertions(+), 18 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 104f15010..fbb5a608a 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -35,9 +35,15 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObjectStreamResponseCreated, OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta, OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone, + OpenAIResponseObjectStreamResponseMcpCallCompleted, + OpenAIResponseObjectStreamResponseMcpCallFailed, + OpenAIResponseObjectStreamResponseMcpCallInProgress, OpenAIResponseObjectStreamResponseOutputItemAdded, OpenAIResponseObjectStreamResponseOutputItemDone, OpenAIResponseObjectStreamResponseOutputTextDelta, + OpenAIResponseObjectStreamResponseWebSearchCallCompleted, + OpenAIResponseObjectStreamResponseWebSearchCallInProgress, + OpenAIResponseObjectStreamResponseWebSearchCallSearching, OpenAIResponseOutput, OpenAIResponseOutputMessageContent, OpenAIResponseOutputMessageContentOutputText, @@ -87,6 +93,15 @@ logger = get_logger(name=__name__, category="openai_responses") OPENAI_RESPONSES_PREFIX = "openai_responses:" +class ToolExecutionResult(BaseModel): + """Result of streaming tool execution.""" + + stream_event: OpenAIResponseObjectStream | None = None + sequence_number: int + final_output_message: OpenAIResponseOutput | None = None + final_input_message: OpenAIMessageParam | None = None + + async def _convert_response_content_to_chat_content( content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]), ) -> str | list[OpenAIChatCompletionContentPartParam]: @@ -587,19 +602,38 @@ class OpenAIResponsesImpl: # execute non-function tool calls for tool_call in non_function_tool_calls: - tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx) + # Find the item_id for this tool call + matching_item_id = None + for index, item_id in tool_call_item_ids.items(): + response_tool_call = chat_response_tool_calls.get(index) + if response_tool_call and response_tool_call.id == tool_call.id: + matching_item_id = 
item_id + break + + # Use a fallback item_id if not found + if not matching_item_id: + matching_item_id = f"tc_{uuid.uuid4()}" + + # Execute tool call with streaming + tool_call_log = None + tool_response_message = None + async for result in self._execute_tool_call( + tool_call, ctx, sequence_number, response_id, len(output_messages), matching_item_id + ): + if result.stream_event: + # Forward streaming events + sequence_number = result.sequence_number + yield result.stream_event + + if result.final_output_message is not None: + tool_call_log = result.final_output_message + tool_response_message = result.final_input_message + sequence_number = result.sequence_number + if tool_call_log: output_messages.append(tool_call_log) # Emit output_item.done event for completed non-function tool call - # Find the item_id for this tool call - matching_item_id = None - for index, item_id in tool_call_item_ids.items(): - response_tool_call = chat_response_tool_calls.get(index) - if response_tool_call and response_tool_call.id == tool_call.id: - matching_item_id = item_id - break - if matching_item_id: sequence_number += 1 yield OpenAIResponseObjectStreamResponseOutputItemDone( @@ -848,7 +882,11 @@ class OpenAIResponsesImpl: self, tool_call: OpenAIChatCompletionToolCall, ctx: ChatCompletionContext, - ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]: + sequence_number: int, + response_id: str, + output_index: int, + item_id: str, + ) -> AsyncIterator[ToolExecutionResult]: from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) @@ -858,8 +896,41 @@ class OpenAIResponsesImpl: tool_kwargs = json.loads(function.arguments) if function.arguments else {} if not function or not tool_call_id or not function.name: - return None, None + yield ToolExecutionResult(sequence_number=sequence_number) + return + # Emit in_progress event based on tool type (only for tools with specific streaming events) + progress_event = None + if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server: + sequence_number += 1 + progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + elif function.name == "web_search": + sequence_number += 1 + progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + # Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec + + if progress_event: + yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number) + + # For web search, emit searching event + if function.name == "web_search": + sequence_number += 1 + searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) + + # Execute the actual tool call error_exc = None result = None try: @@ -894,6 +965,33 @@ class OpenAIResponsesImpl: except Exception as e: error_exc = e + # Emit completion or failure event based on result (only for tools with specific streaming events) + has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)) + completion_event = None + + if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server: + sequence_number += 1 + if has_error: + 
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed( + sequence_number=sequence_number, + ) + else: + completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( + sequence_number=sequence_number, + ) + elif function.name == "web_search": + sequence_number += 1 + completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + # Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec + + if completion_event: + yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number) + + # Build the result message and input message if function.name in ctx.mcp_tool_to_server: from llama_stack.apis.agents.openai_responses import ( OpenAIResponseOutputMessageMCPCall, @@ -907,9 +1005,9 @@ class OpenAIResponsesImpl: ) if error_exc: message.error = str(error_exc) - elif (result.error_code and result.error_code > 0) or result.error_message: + elif (result and result.error_code and result.error_code > 0) or (result and result.error_message): message.error = f"Error (code {result.error_code}): {result.error_message}" - elif result.content: + elif result and result.content: message.output = interleaved_content_as_str(result.content) else: if function.name == "web_search": @@ -917,7 +1015,7 @@ class OpenAIResponsesImpl: id=tool_call_id, status="completed", ) - if error_exc or (result.error_code and result.error_code > 0) or result.error_message: + if has_error: message.status = "failed" elif function.name == "knowledge_search": message = OpenAIResponseOutputMessageFileSearchToolCall( @@ -925,7 +1023,7 @@ class OpenAIResponsesImpl: queries=[tool_kwargs.get("query", "")], status="completed", ) - if "document_ids" in result.metadata: + if result and "document_ids" in result.metadata: message.results = [] for i, doc_id in enumerate(result.metadata["document_ids"]): text = result.metadata["chunks"][i] if "chunks" in result.metadata else None @@ -939,7 +1037,7 @@ class OpenAIResponsesImpl: attributes={}, ) ) - if error_exc or (result.error_code and result.error_code > 0) or result.error_message: + if has_error: message.status = "failed" else: raise ValueError(f"Unknown tool {function.name} called") @@ -971,10 +1069,13 @@ class OpenAIResponsesImpl: raise ValueError(f"Unknown result content type: {type(result.content)}") input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) else: - text = str(error_exc) + text = str(error_exc) if error_exc else "Tool execution failed" input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) - return message, input_message + # Yield the final result + yield ToolExecutionResult( + sequence_number=sequence_number, final_output_message=message, final_input_message=input_message + ) def _is_function_tool_call( diff --git a/tests/integration/non_ci/responses/test_responses.py b/tests/integration/non_ci/responses/test_responses.py index 6092346b0..776e3cf30 100644 --- a/tests/integration/non_ci/responses/test_responses.py +++ b/tests/integration/non_ci/responses/test_responses.py @@ -598,6 +598,10 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_ item_added_events = [chunk for chunk in chunks if chunk.type == "response.output_item.added"] item_done_events = [chunk for chunk in chunks if chunk.type == "response.output_item.done"] + # Should have tool execution progress events + mcp_in_progress_events = [chunk for 
chunk in chunks if chunk.type == "response.mcp_call.in_progress"] + mcp_completed_events = [chunk for chunk in chunks if chunk.type == "response.mcp_call.completed"] + # Verify we have substantial streaming activity (not just batch events) assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks" @@ -609,6 +613,24 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_ assert len(item_added_events) > 0, f"Expected response.output_item.added events, got chunk types: {chunk_types}" assert len(item_done_events) > 0, f"Expected response.output_item.done events, got chunk types: {chunk_types}" + # Should have tool execution progress events + assert len(mcp_in_progress_events) > 0, ( + f"Expected response.mcp_call.in_progress events, got chunk types: {chunk_types}" + ) + assert len(mcp_completed_events) > 0, ( + f"Expected response.mcp_call.completed events, got chunk types: {chunk_types}" + ) + # MCP failed events are optional (only if errors occur) + + # Verify progress events have proper structure + for progress_event in mcp_in_progress_events: + assert hasattr(progress_event, "item_id"), "Progress event should have 'item_id' field" + assert hasattr(progress_event, "output_index"), "Progress event should have 'output_index' field" + assert hasattr(progress_event, "sequence_number"), "Progress event should have 'sequence_number' field" + + for completed_event in mcp_completed_events: + assert hasattr(completed_event, "sequence_number"), "Completed event should have 'sequence_number' field" + # Verify delta events have proper structure for delta_event in delta_events: assert hasattr(delta_event, "delta"), "Delta event should have 'delta' field" From e1e161553c323e2477f24c7091a74cb51f18ef78 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 13 Aug 2025 16:34:26 -0700 Subject: [PATCH 23/45] feat(responses): add MCP argument streaming and content part events (#3136) # What does this PR do? Adds content part streaming events to the OpenAI-compatible Responses API to support more granular streaming of response content. This introduces: 1. New schema types for content parts: `OpenAIResponseContentPart` with variants for text output and refusals 2. New streaming event types: - `OpenAIResponseObjectStreamResponseContentPartAdded` for when content parts begin - `OpenAIResponseObjectStreamResponseContentPartDone` for when content parts complete 3. Implementation in the reference provider to emit these events during streaming responses. Also emits MCP arguments just like function call ones. ## Test Plan Updated existing streaming tests to verify content part events are properly emitted --- docs/_static/llama-stack-spec.html | 137 ++++++++++++++++++ docs/_static/llama-stack-spec.yaml | 111 ++++++++++++++ llama_stack/apis/agents/openai_responses.py | 58 ++++++++ .../agents/meta_reference/openai_responses.py | 96 ++++++++++-- .../non_ci/responses/test_responses.py | 77 ++++++++-- .../meta_reference/test_openai_responses.py | 36 ++++- 6 files changed, 480 insertions(+), 35 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 25f916d87..0549dda21 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -8821,6 +8821,61 @@ "title": "OpenAIResponseOutputMessageMCPListTools", "description": "MCP list tools output message containing available tools from an MCP server." 
}, + "OpenAIResponseContentPart": { + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIResponseContentPartOutputText" + }, + { + "$ref": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "output_text": "#/components/schemas/OpenAIResponseContentPartOutputText", + "refusal": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + } + }, + "OpenAIResponseContentPartOutputText": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "output_text", + "default": "output_text" + }, + "text": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "text" + ], + "title": "OpenAIResponseContentPartOutputText" + }, + "OpenAIResponseContentPartRefusal": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "refusal", + "default": "refusal" + }, + "refusal": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "refusal" + ], + "title": "OpenAIResponseContentPartRefusal" + }, "OpenAIResponseObjectStream": { "oneOf": [ { @@ -8877,6 +8932,12 @@ { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted" }, + { + "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded" + }, + { + "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone" + }, { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } @@ -8902,6 +8963,8 @@ "response.mcp_call.in_progress": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress", "response.mcp_call.failed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed", "response.mcp_call.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted", + "response.content_part.added": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded", + "response.content_part.done": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone", "response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } } @@ -8928,6 +8991,80 @@ "title": "OpenAIResponseObjectStreamResponseCompleted", "description": "Streaming event indicating a response has been completed." }, + "OpenAIResponseObjectStreamResponseContentPartAdded": { + "type": "object", + "properties": { + "response_id": { + "type": "string", + "description": "Unique identifier of the response containing this content" + }, + "item_id": { + "type": "string", + "description": "Unique identifier of the output item containing this content part" + }, + "part": { + "$ref": "#/components/schemas/OpenAIResponseContentPart", + "description": "The content part that was added" + }, + "sequence_number": { + "type": "integer", + "description": "Sequential number for ordering streaming events" + }, + "type": { + "type": "string", + "const": "response.content_part.added", + "default": "response.content_part.added", + "description": "Event type identifier, always \"response.content_part.added\"" + } + }, + "additionalProperties": false, + "required": [ + "response_id", + "item_id", + "part", + "sequence_number", + "type" + ], + "title": "OpenAIResponseObjectStreamResponseContentPartAdded", + "description": "Streaming event for when a new content part is added to a response item." 
+ }, + "OpenAIResponseObjectStreamResponseContentPartDone": { + "type": "object", + "properties": { + "response_id": { + "type": "string", + "description": "Unique identifier of the response containing this content" + }, + "item_id": { + "type": "string", + "description": "Unique identifier of the output item containing this content part" + }, + "part": { + "$ref": "#/components/schemas/OpenAIResponseContentPart", + "description": "The completed content part" + }, + "sequence_number": { + "type": "integer", + "description": "Sequential number for ordering streaming events" + }, + "type": { + "type": "string", + "const": "response.content_part.done", + "default": "response.content_part.done", + "description": "Event type identifier, always \"response.content_part.done\"" + } + }, + "additionalProperties": false, + "required": [ + "response_id", + "item_id", + "part", + "sequence_number", + "type" + ], + "title": "OpenAIResponseObjectStreamResponseContentPartDone", + "description": "Streaming event for when a content part is completed." + }, "OpenAIResponseObjectStreamResponseCreated": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 43e9fa95a..aa47cd58d 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6441,6 +6441,43 @@ components: title: OpenAIResponseOutputMessageMCPListTools description: >- MCP list tools output message containing available tools from an MCP server. + OpenAIResponseContentPart: + oneOf: + - $ref: '#/components/schemas/OpenAIResponseContentPartOutputText' + - $ref: '#/components/schemas/OpenAIResponseContentPartRefusal' + discriminator: + propertyName: type + mapping: + output_text: '#/components/schemas/OpenAIResponseContentPartOutputText' + refusal: '#/components/schemas/OpenAIResponseContentPartRefusal' + OpenAIResponseContentPartOutputText: + type: object + properties: + type: + type: string + const: output_text + default: output_text + text: + type: string + additionalProperties: false + required: + - type + - text + title: OpenAIResponseContentPartOutputText + OpenAIResponseContentPartRefusal: + type: object + properties: + type: + type: string + const: refusal + default: refusal + refusal: + type: string + additionalProperties: false + required: + - type + - refusal + title: OpenAIResponseContentPartRefusal OpenAIResponseObjectStream: oneOf: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' @@ -6461,6 +6498,8 @@ components: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted' + - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded' + - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' discriminator: propertyName: type @@ -6483,6 +6522,8 @@ components: response.mcp_call.in_progress: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress' response.mcp_call.failed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed' response.mcp_call.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted' + response.content_part.added: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded' + response.content_part.done: 
'#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone' response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' "OpenAIResponseObjectStreamResponseCompleted": type: object @@ -6504,6 +6545,76 @@ components: OpenAIResponseObjectStreamResponseCompleted description: >- Streaming event indicating a response has been completed. + "OpenAIResponseObjectStreamResponseContentPartAdded": + type: object + properties: + response_id: + type: string + description: >- + Unique identifier of the response containing this content + item_id: + type: string + description: >- + Unique identifier of the output item containing this content part + part: + $ref: '#/components/schemas/OpenAIResponseContentPart' + description: The content part that was added + sequence_number: + type: integer + description: >- + Sequential number for ordering streaming events + type: + type: string + const: response.content_part.added + default: response.content_part.added + description: >- + Event type identifier, always "response.content_part.added" + additionalProperties: false + required: + - response_id + - item_id + - part + - sequence_number + - type + title: >- + OpenAIResponseObjectStreamResponseContentPartAdded + description: >- + Streaming event for when a new content part is added to a response item. + "OpenAIResponseObjectStreamResponseContentPartDone": + type: object + properties: + response_id: + type: string + description: >- + Unique identifier of the response containing this content + item_id: + type: string + description: >- + Unique identifier of the output item containing this content part + part: + $ref: '#/components/schemas/OpenAIResponseContentPart' + description: The completed content part + sequence_number: + type: integer + description: >- + Sequential number for ordering streaming events + type: + type: string + const: response.content_part.done + default: response.content_part.done + description: >- + Event type identifier, always "response.content_part.done" + additionalProperties: false + required: + - response_id + - item_id + - part + - sequence_number + - type + title: >- + OpenAIResponseObjectStreamResponseContentPartDone + description: >- + Streaming event for when a content part is completed. "OpenAIResponseObjectStreamResponseCreated": type: object properties: diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index 8574104dc..591992479 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -623,6 +623,62 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel): type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed" +@json_schema_type +class OpenAIResponseContentPartOutputText(BaseModel): + type: Literal["output_text"] = "output_text" + text: str + # TODO: add annotations, logprobs, etc. + + +@json_schema_type +class OpenAIResponseContentPartRefusal(BaseModel): + type: Literal["refusal"] = "refusal" + refusal: str + + +OpenAIResponseContentPart = Annotated[ + OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal, + Field(discriminator="type"), +] +register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart") + + +@json_schema_type +class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel): + """Streaming event for when a new content part is added to a response item. 
+ + :param response_id: Unique identifier of the response containing this content + :param item_id: Unique identifier of the output item containing this content part + :param part: The content part that was added + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.content_part.added" + """ + + response_id: str + item_id: str + part: OpenAIResponseContentPart + sequence_number: int + type: Literal["response.content_part.added"] = "response.content_part.added" + + +@json_schema_type +class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel): + """Streaming event for when a content part is completed. + + :param response_id: Unique identifier of the response containing this content + :param item_id: Unique identifier of the output item containing this content part + :param part: The completed content part + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.content_part.done" + """ + + response_id: str + item_id: str + part: OpenAIResponseContentPart + sequence_number: int + type: Literal["response.content_part.done"] = "response.content_part.done" + + OpenAIResponseObjectStream = Annotated[ OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseOutputItemAdded @@ -642,6 +698,8 @@ OpenAIResponseObjectStream = Annotated[ | OpenAIResponseObjectStreamResponseMcpCallInProgress | OpenAIResponseObjectStreamResponseMcpCallFailed | OpenAIResponseObjectStreamResponseMcpCallCompleted + | OpenAIResponseObjectStreamResponseContentPartAdded + | OpenAIResponseObjectStreamResponseContentPartDone | OpenAIResponseObjectStreamResponseCompleted, Field(discriminator="type"), ] diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index fbb5a608a..6aca4d68e 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -20,6 +20,7 @@ from llama_stack.apis.agents.openai_responses import ( ListOpenAIResponseInputItem, ListOpenAIResponseObject, OpenAIDeleteResponseObject, + OpenAIResponseContentPartOutputText, OpenAIResponseInput, OpenAIResponseInputFunctionToolCallOutput, OpenAIResponseInputMessageContent, @@ -32,9 +33,13 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectStream, OpenAIResponseObjectStreamResponseCompleted, + OpenAIResponseObjectStreamResponseContentPartAdded, + OpenAIResponseObjectStreamResponseContentPartDone, OpenAIResponseObjectStreamResponseCreated, OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta, OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone, + OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta, + OpenAIResponseObjectStreamResponseMcpCallArgumentsDone, OpenAIResponseObjectStreamResponseMcpCallCompleted, OpenAIResponseObjectStreamResponseMcpCallFailed, OpenAIResponseObjectStreamResponseMcpCallInProgress, @@ -475,6 +480,8 @@ class OpenAIResponsesImpl: message_item_id = f"msg_{uuid.uuid4()}" # Track tool call items for streaming events tool_call_item_ids: dict[int, str] = {} + # Track content parts for streaming events + content_part_emitted = False async for chunk in completion_result: chat_response_id = chunk.id @@ -483,6 +490,18 @@ class OpenAIResponsesImpl: for chunk_choice in chunk.choices: # Emit incremental text content as delta 
events if chunk_choice.delta.content: + # Emit content_part.added event for first text chunk + if not content_part_emitted: + content_part_emitted = True + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseContentPartAdded( + response_id=response_id, + item_id=message_item_id, + part=OpenAIResponseContentPartOutputText( + text="", # Will be filled incrementally via text deltas + ), + sequence_number=sequence_number, + ) sequence_number += 1 yield OpenAIResponseObjectStreamResponseOutputTextDelta( content_index=0, @@ -529,16 +548,33 @@ class OpenAIResponsesImpl: sequence_number=sequence_number, ) - # Stream function call arguments as they arrive + # Stream tool call arguments as they arrive (differentiate between MCP and function calls) if tool_call.function and tool_call.function.arguments: tool_call_item_id = tool_call_item_ids[tool_call.index] sequence_number += 1 - yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta( - delta=tool_call.function.arguments, - item_id=tool_call_item_id, - output_index=len(output_messages), - sequence_number=sequence_number, + + # Check if this is an MCP tool call + is_mcp_tool = ( + ctx.mcp_tool_to_server + and tool_call.function.name + and tool_call.function.name in ctx.mcp_tool_to_server ) + if is_mcp_tool: + # Emit MCP-specific argument delta event + yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta( + delta=tool_call.function.arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=sequence_number, + ) + else: + # Emit function call argument delta event + yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta( + delta=tool_call.function.arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=sequence_number, + ) # Accumulate arguments for final response (only for subsequent chunks) if not is_new_tool_call: @@ -546,27 +582,55 @@ class OpenAIResponsesImpl: response_tool_call.function.arguments or "" ) + tool_call.function.arguments - # Emit function_call_arguments.done events for completed tool calls + # Emit arguments.done events for completed tool calls (differentiate between MCP and function calls) for tool_call_index in sorted(chat_response_tool_calls.keys()): tool_call_item_id = tool_call_item_ids[tool_call_index] final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or "" + tool_call_name = chat_response_tool_calls[tool_call_index].function.name + + # Check if this is an MCP tool call + is_mcp_tool = ctx.mcp_tool_to_server and tool_call_name and tool_call_name in ctx.mcp_tool_to_server sequence_number += 1 - yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone( - arguments=final_arguments, - item_id=tool_call_item_id, - output_index=len(output_messages), - sequence_number=sequence_number, - ) + if is_mcp_tool: + # Emit MCP-specific argument done event + yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDone( + arguments=final_arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=sequence_number, + ) + else: + # Emit function call argument done event + yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone( + arguments=final_arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=sequence_number, + ) # Convert collected chunks to complete response if chat_response_tool_calls: tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())] - - # when there are tool 
calls, we need to clear the content - chat_response_content = [] else: tool_calls = None + # Emit content_part.done event if text content was streamed (before content gets cleared) + if content_part_emitted: + final_text = "".join(chat_response_content) + sequence_number += 1 + yield OpenAIResponseObjectStreamResponseContentPartDone( + response_id=response_id, + item_id=message_item_id, + part=OpenAIResponseContentPartOutputText( + text=final_text, + ), + sequence_number=sequence_number, + ) + + # Clear content when there are tool calls (OpenAI spec behavior) + if chat_response_tool_calls: + chat_response_content = [] + assistant_message = OpenAIAssistantMessageParam( content="".join(chat_response_content), tool_calls=tool_calls, diff --git a/tests/integration/non_ci/responses/test_responses.py b/tests/integration/non_ci/responses/test_responses.py index 776e3cf30..04266eec8 100644 --- a/tests/integration/non_ci/responses/test_responses.py +++ b/tests/integration/non_ci/responses/test_responses.py @@ -590,9 +590,17 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_ # Verify tool call streaming events are present chunk_types = [chunk.type for chunk in chunks] - # Should have function call arguments delta events for tool calls - delta_events = [chunk for chunk in chunks if chunk.type == "response.function_call_arguments.delta"] - done_events = [chunk for chunk in chunks if chunk.type == "response.function_call_arguments.done"] + # Should have function call or MCP arguments delta/done events for tool calls + delta_events = [ + chunk + for chunk in chunks + if chunk.type in ["response.function_call_arguments.delta", "response.mcp_call.arguments.delta"] + ] + done_events = [ + chunk + for chunk in chunks + if chunk.type in ["response.function_call_arguments.done", "response.mcp_call.arguments.done"] + ] # Should have output item events for tool calls item_added_events = [chunk for chunk in chunks if chunk.type == "response.output_item.added"] @@ -606,8 +614,12 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_ assert len(chunks) > 10, f"Expected rich streaming with many events, got only {len(chunks)} chunks" # Since this test involves MCP tool calls, we should see streaming events - assert len(delta_events) > 0, f"Expected function_call_arguments.delta events, got chunk types: {chunk_types}" - assert len(done_events) > 0, f"Expected function_call_arguments.done events, got chunk types: {chunk_types}" + assert len(delta_events) > 0, ( + f"Expected function_call_arguments.delta or mcp_call.arguments.delta events, got chunk types: {chunk_types}" + ) + assert len(done_events) > 0, ( + f"Expected function_call_arguments.done or mcp_call.arguments.done events, got chunk types: {chunk_types}" + ) # Should have output item events for function calls assert len(item_added_events) > 0, f"Expected response.output_item.added events, got chunk types: {chunk_types}" @@ -670,22 +682,32 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_ assert isinstance(done_event.output_index, int), "Output index should be integer" assert done_event.output_index >= 0, "Output index should be non-negative" - # Group function call argument events by item_id (these should have proper tracking) - function_call_events_by_item_id = {} + # Group function call and MCP argument events by item_id (these should have proper tracking) + argument_events_by_item_id = {} for chunk in chunks: if hasattr(chunk, "item_id") and chunk.type in [ 
"response.function_call_arguments.delta", "response.function_call_arguments.done", + "response.mcp_call.arguments.delta", + "response.mcp_call.arguments.done", ]: item_id = chunk.item_id - if item_id not in function_call_events_by_item_id: - function_call_events_by_item_id[item_id] = [] - function_call_events_by_item_id[item_id].append(chunk) + if item_id not in argument_events_by_item_id: + argument_events_by_item_id[item_id] = [] + argument_events_by_item_id[item_id].append(chunk) - for item_id, related_events in function_call_events_by_item_id.items(): - # Should have at least one delta and one done event for a complete function call - delta_events = [e for e in related_events if e.type == "response.function_call_arguments.delta"] - done_events = [e for e in related_events if e.type == "response.function_call_arguments.done"] + for item_id, related_events in argument_events_by_item_id.items(): + # Should have at least one delta and one done event for a complete tool call + delta_events = [ + e + for e in related_events + if e.type in ["response.function_call_arguments.delta", "response.mcp_call.arguments.delta"] + ] + done_events = [ + e + for e in related_events + if e.type in ["response.function_call_arguments.done", "response.mcp_call.arguments.done"] + ] assert len(delta_events) > 0, f"Item {item_id} should have at least one delta event" assert len(done_events) == 1, f"Item {item_id} should have exactly one done event" @@ -694,6 +716,33 @@ def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_ for event in related_events: assert event.item_id == item_id, f"Event should have consistent item_id {item_id}, got {event.item_id}" + # Verify content part events if they exist (for text streaming) + content_part_added_events = [chunk for chunk in chunks if chunk.type == "response.content_part.added"] + content_part_done_events = [chunk for chunk in chunks if chunk.type == "response.content_part.done"] + + # Content part events should be paired (if any exist) + if len(content_part_added_events) > 0: + assert len(content_part_done_events) > 0, ( + "Should have content_part.done events if content_part.added events exist" + ) + + # Verify content part event structure + for added_event in content_part_added_events: + assert hasattr(added_event, "response_id"), "Content part added event should have response_id" + assert hasattr(added_event, "item_id"), "Content part added event should have item_id" + assert hasattr(added_event, "part"), "Content part added event should have part" + + # TODO: enable this after the client types are updated + # assert added_event.part.type == "output_text", "Content part should be an output_text" + + for done_event in content_part_done_events: + assert hasattr(done_event, "response_id"), "Content part done event should have response_id" + assert hasattr(done_event, "item_id"), "Content part done event should have item_id" + assert hasattr(done_event, "part"), "Content part done event should have part" + + # TODO: enable this after the client types are updated + # assert len(done_event.part.text) > 0, "Content part should have text when done" + # Basic pairing check: each output_item.added should be followed by some activity # (but we can't enforce strict 1:1 pairing due to the complexity of multi-turn scenarios) assert len(item_added_events) > 0, "Should have at least one output_item.added event" diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py 
b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index 855a525e9..4132a74a3 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -136,9 +136,12 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m input=input_text, model=model, temperature=0.1, + stream=True, # Enable streaming to test content part events ) - # Verify + # For streaming response, collect all chunks + chunks = [chunk async for chunk in result] + mock_inference_api.openai_chat_completion.assert_called_once_with( model=model, messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)], @@ -147,11 +150,32 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m stream=True, temperature=0.1, ) + + # Should have content part events for text streaming + # Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed + assert len(chunks) >= 4 + assert chunks[0].type == "response.created" + + # Check for content part events + content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"] + content_part_done_events = [c for c in chunks if c.type == "response.content_part.done"] + text_delta_events = [c for c in chunks if c.type == "response.output_text.delta"] + + assert len(content_part_added_events) >= 1, "Should have content_part.added event for text" + assert len(content_part_done_events) >= 1, "Should have content_part.done event for text" + assert len(text_delta_events) >= 1, "Should have text delta events" + + # Verify final event is completion + assert chunks[-1].type == "response.completed" + + # When streaming, the final response is in the last chunk + final_response = chunks[-1].response + assert final_response.model == model + assert len(final_response.output) == 1 + assert isinstance(final_response.output[0], OpenAIResponseMessage) + openai_responses_impl.responses_store.store_response_object.assert_called_once() - assert result.model == model - assert len(result.output) == 1 - assert isinstance(result.output[0], OpenAIResponseMessage) - assert result.output[0].content[0].text == "Dublin" + assert final_response.output[0].content[0].text == "Dublin" async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api): @@ -272,6 +296,8 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_ # Check that we got the content from our mocked tool execution result chunks = [chunk async for chunk in result] + + # Verify event types # Should have: response.created, output_item.added, function_call_arguments.delta, # function_call_arguments.done, output_item.done, response.completed assert len(chunks) == 6 From 46ff302d87562cf266d2a304f7409593ac7bb0ca Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 13 Aug 2025 18:38:34 -0700 Subject: [PATCH 24/45] chore: Remove Trendshift badge from README (#3137) ## Summary - This links to a scammy looking website with ads. 
## Test plan --- README.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/README.md b/README.md index 8db4580a2..4df4a5372 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,5 @@ # Llama Stack -meta-llama%2Fllama-stack | Trendshift - ------ [![PyPI version](https://img.shields.io/pypi/v/llama_stack.svg)](https://pypi.org/project/llama_stack/) [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/) [![License](https://img.shields.io/pypi/l/llama_stack.svg)](https://github.com/meta-llama/llama-stack/blob/main/LICENSE) From de692162afe0151ebac69321effd069b194d2754 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 14 Aug 2025 08:42:02 -0500 Subject: [PATCH 25/45] feat: add batches API with OpenAI compatibility (#3088) Add complete batches API implementation with protocol, providers, and tests: Core Infrastructure: - Add batches API protocol using OpenAI Batch types directly - Add Api.batches enum value and protocol mapping in resolver - Add OpenAI "batch" file purpose support - Include proper error handling (ConflictError, ResourceNotFoundError) Reference Provider: - Add ReferenceBatchesImpl with full CRUD operations (create, retrieve, cancel, list) - Implement background batch processing with configurable concurrency - Add SQLite KVStore backend for persistence - Support /v1/chat/completions endpoint with request validation Comprehensive Test Suite: - Add unit tests for provider implementation with validation - Add integration tests for end-to-end batch processing workflows - Add error handling tests for validation, malformed inputs, and edge cases Configuration: - Add max_concurrent_batches and max_concurrent_requests_per_batch options - Add provider documentation with sample configurations Test with - ``` $ uv run llama stack build --image-type venv --providers inference=YOU_PICK,files=inline::localfs,batches=inline::reference --run & $ LLAMA_STACK_CONFIG=http://localhost:8321 uv run pytest tests/unit/providers/batches tests/integration/batches --text-model YOU_PICK ``` addresses #3066 --- docs/_static/llama-stack-spec.html | 6 +- docs/_static/llama-stack-spec.yaml | 2 + docs/source/concepts/apis.md | 1 + docs/source/providers/agents/index.md | 9 + docs/source/providers/batches/index.md | 21 + .../providers/batches/inline_reference.md | 23 + docs/source/providers/eval/index.md | 2 + docs/source/providers/inference/index.md | 6 + llama_stack/apis/batches/__init__.py | 9 + llama_stack/apis/batches/batches.py | 89 +++ llama_stack/apis/common/errors.py | 6 + llama_stack/apis/datatypes.py | 2 + llama_stack/apis/files/files.py | 1 + llama_stack/core/resolver.py | 2 + llama_stack/core/server/server.py | 5 + .../providers/inline/batches/__init__.py | 5 + .../inline/batches/reference/__init__.py | 36 + .../inline/batches/reference/batches.py | 553 +++++++++++++ .../inline/batches/reference/config.py | 40 + llama_stack/providers/registry/batches.py | 26 + scripts/provider_codegen.py | 22 + tests/integration/batches/__init__.py | 5 + tests/integration/batches/conftest.py | 122 +++ tests/integration/batches/test_batches.py | 270 +++++++ .../batches/test_batches_errors.py | 693 ++++++++++++++++ .../unit/providers/batches/test_reference.py | 753 ++++++++++++++++++ 26 files changed, 2707 insertions(+), 2 deletions(-) create mode 100644 docs/source/providers/batches/index.md create mode 100644 docs/source/providers/batches/inline_reference.md create mode 100644 llama_stack/apis/batches/__init__.py create mode 100644 
llama_stack/apis/batches/batches.py create mode 100644 llama_stack/providers/inline/batches/__init__.py create mode 100644 llama_stack/providers/inline/batches/reference/__init__.py create mode 100644 llama_stack/providers/inline/batches/reference/batches.py create mode 100644 llama_stack/providers/inline/batches/reference/config.py create mode 100644 llama_stack/providers/registry/batches.py create mode 100644 tests/integration/batches/__init__.py create mode 100644 tests/integration/batches/conftest.py create mode 100644 tests/integration/batches/test_batches.py create mode 100644 tests/integration/batches/test_batches_errors.py create mode 100644 tests/unit/providers/batches/test_reference.py diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 0549dda21..b36626719 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -14767,7 +14767,8 @@ "OpenAIFilePurpose": { "type": "string", "enum": [ - "assistants" + "assistants", + "batch" ], "title": "OpenAIFilePurpose", "description": "Valid purpose values for OpenAI Files API." @@ -14844,7 +14845,8 @@ "purpose": { "type": "string", "enum": [ - "assistants" + "assistants", + "batch" ], "description": "The intended purpose of the file" } diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index aa47cd58d..e7733b3c3 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -10951,6 +10951,7 @@ components: type: string enum: - assistants + - batch title: OpenAIFilePurpose description: >- Valid purpose values for OpenAI Files API. @@ -11019,6 +11020,7 @@ components: type: string enum: - assistants + - batch description: The intended purpose of the file additionalProperties: false required: diff --git a/docs/source/concepts/apis.md b/docs/source/concepts/apis.md index 5a10d6498..f8f73a928 100644 --- a/docs/source/concepts/apis.md +++ b/docs/source/concepts/apis.md @@ -18,3 +18,4 @@ We are working on adding a few more APIs to complete the application lifecycle. - **Batch Inference**: run inference on a dataset of inputs - **Batch Agents**: run agents on a dataset of inputs - **Synthetic Data Generation**: generate synthetic data for model development +- **Batches**: OpenAI-compatible batch management for inference diff --git a/docs/source/providers/agents/index.md b/docs/source/providers/agents/index.md index 92bf9edc0..a2c48d4b9 100644 --- a/docs/source/providers/agents/index.md +++ b/docs/source/providers/agents/index.md @@ -2,6 +2,15 @@ ## Overview +Agents API for creating and interacting with agentic systems. + + Main functionalities provided by this API: + - Create agents with specific instructions and ability to use tools. + - Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn". + - Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details). + - Agents can be provided with various shields (see the Safety API for more details). + - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. + This section contains documentation for all available providers for the **agents** API. 
## Providers diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md new file mode 100644 index 000000000..2a39a626c --- /dev/null +++ b/docs/source/providers/batches/index.md @@ -0,0 +1,21 @@ +# Batches + +## Overview + +Protocol for batch processing API operations. + + The Batches API enables efficient processing of multiple requests in a single operation, + particularly useful for processing large datasets, batch evaluation workflows, and + cost-effective inference at scale. + + Note: This API is currently under active development and may undergo changes. + +This section contains documentation for all available providers for the **batches** API. + +## Providers + +```{toctree} +:maxdepth: 1 + +inline_reference +``` diff --git a/docs/source/providers/batches/inline_reference.md b/docs/source/providers/batches/inline_reference.md new file mode 100644 index 000000000..a58e5124d --- /dev/null +++ b/docs/source/providers/batches/inline_reference.md @@ -0,0 +1,23 @@ +# inline::reference + +## Description + +Reference implementation of batches API with KVStore persistence. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. | +| `max_concurrent_batches` | `` | No | 1 | Maximum number of concurrent batches to process simultaneously. | +| `max_concurrent_requests_per_batch` | `` | No | 10 | Maximum number of concurrent requests to process per batch. | + +## Sample Configuration + +```yaml +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db + +``` + diff --git a/docs/source/providers/eval/index.md b/docs/source/providers/eval/index.md index d180d256c..a14fada1d 100644 --- a/docs/source/providers/eval/index.md +++ b/docs/source/providers/eval/index.md @@ -2,6 +2,8 @@ ## Overview +Llama Stack Evaluation API for running evaluations on model and agent candidates. + This section contains documentation for all available providers for the **eval** API. ## Providers diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index 38781e5eb..b6d215474 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -2,6 +2,12 @@ ## Overview +Llama Stack Inference API for generating completions, chat completions, and embeddings. + + This API provides the raw interface to the underlying models. Two kinds of models are supported: + - LLM models: these models generate "raw" and "chat" (conversational) completions. + - Embedding models: these models generate embeddings to be used for semantic search. + This section contains documentation for all available providers for the **inference** API. ## Providers diff --git a/llama_stack/apis/batches/__init__.py b/llama_stack/apis/batches/__init__.py new file mode 100644 index 000000000..9ce7d3d75 --- /dev/null +++ b/llama_stack/apis/batches/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from .batches import Batches, BatchObject, ListBatchesResponse + +__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py new file mode 100644 index 000000000..9297d8597 --- /dev/null +++ b/llama_stack/apis/batches/batches.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Literal, Protocol, runtime_checkable + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type, webmethod + +try: + from openai.types import Batch as BatchObject +except ImportError as e: + raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e + + +@json_schema_type +class ListBatchesResponse(BaseModel): + """Response containing a list of batch objects.""" + + object: Literal["list"] = "list" + data: list[BatchObject] = Field(..., description="List of batch objects") + first_id: str | None = Field(default=None, description="ID of the first batch in the list") + last_id: str | None = Field(default=None, description="ID of the last batch in the list") + has_more: bool = Field(default=False, description="Whether there are more batches available") + + +@runtime_checkable +class Batches(Protocol): + """Protocol for batch processing API operations. + + The Batches API enables efficient processing of multiple requests in a single operation, + particularly useful for processing large datasets, batch evaluation workflows, and + cost-effective inference at scale. + + Note: This API is currently under active development and may undergo changes. + """ + + @webmethod(route="/openai/v1/batches", method="POST") + async def create_batch( + self, + input_file_id: str, + endpoint: str, + completion_window: Literal["24h"], + metadata: dict[str, str] | None = None, + ) -> BatchObject: + """Create a new batch for processing multiple API requests. + + :param input_file_id: The ID of an uploaded file containing requests for the batch. + :param endpoint: The endpoint to be used for all requests in the batch. + :param completion_window: The time window within which the batch should be processed. + :param metadata: Optional metadata for the batch. + :returns: The created batch object. + """ + ... + + @webmethod(route="/openai/v1/batches/{batch_id}", method="GET") + async def retrieve_batch(self, batch_id: str) -> BatchObject: + """Retrieve information about a specific batch. + + :param batch_id: The ID of the batch to retrieve. + :returns: The batch object. + """ + ... + + @webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST") + async def cancel_batch(self, batch_id: str) -> BatchObject: + """Cancel a batch that is in progress. + + :param batch_id: The ID of the batch to cancel. + :returns: The updated batch object. + """ + ... + + @webmethod(route="/openai/v1/batches", method="GET") + async def list_batches( + self, + after: str | None = None, + limit: int = 20, + ) -> ListBatchesResponse: + """List all batches for the current user. + + :param after: A cursor for pagination; returns batches after this batch ID. + :param limit: Number of batches to return (default 20, max 100). + :returns: A list of batch objects. + """ + ... 
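The protocol above maps directly onto the OpenAI Batches client surface. As a quick orientation, here is a minimal client-side sketch (not part of this patch) of how the endpoints are expected to be exercised end to end; the server URL, API key, and model name below are placeholder assumptions, not values taken from the patch.

```python
# Minimal sketch of driving the Batches API with the OpenAI client.
# Assumptions (not from this patch): a Llama Stack server on localhost:8321
# exposing the OpenAI-compatible routes under /v1/openai/v1, and a registered
# LLM called "my-model".
import json
import time
from io import BytesIO

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# Each JSONL line must carry custom_id, method, url, and body.
requests = [
    {
        "custom_id": "req-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"model": "my-model", "messages": [{"role": "user", "content": "Hello"}]},
    }
]
buf = BytesIO("\n".join(json.dumps(r) for r in requests).encode("utf-8"))
buf.name = "batch_input.jsonl"
input_file = client.files.create(file=buf, purpose="batch")

# Create the batch, then poll until it reaches a terminal status.
batch = client.batches.create(
    input_file_id=input_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)
while batch.status not in ("completed", "failed", "cancelled", "expired"):
    time.sleep(5)
    batch = client.batches.retrieve(batch.id)

# Successful results land in output_file_id, per-request errors in error_file_id.
if batch.output_file_id:
    print(client.files.content(batch.output_file_id).content.decode("utf-8"))
```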
diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py index 6e0fa0b3c..7104d8db6 100644 --- a/llama_stack/apis/common/errors.py +++ b/llama_stack/apis/common/errors.py @@ -64,6 +64,12 @@ class SessionNotFoundError(ValueError): super().__init__(message) +class ConflictError(ValueError): + """raised when an operation cannot be performed due to a conflict with the current state""" + + pass + + class ModelTypeError(TypeError): """raised when a model is present but not the correct type""" diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index cabe46a2f..87fc95917 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -86,6 +86,7 @@ class Api(Enum, metaclass=DynamicApiMeta): :cvar inference: Text generation, chat completions, and embeddings :cvar safety: Content moderation and safety shields :cvar agents: Agent orchestration and execution + :cvar batches: Batch processing for asynchronous API requests :cvar vector_io: Vector database operations and queries :cvar datasetio: Dataset input/output operations :cvar scoring: Model output evaluation and scoring @@ -108,6 +109,7 @@ class Api(Enum, metaclass=DynamicApiMeta): inference = "inference" safety = "safety" agents = "agents" + batches = "batches" vector_io = "vector_io" datasetio = "datasetio" scoring = "scoring" diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py index ba8701e23..a1b9dd4dc 100644 --- a/llama_stack/apis/files/files.py +++ b/llama_stack/apis/files/files.py @@ -22,6 +22,7 @@ class OpenAIFilePurpose(StrEnum): """ ASSISTANTS = "assistants" + BATCH = "batch" # TODO: Add other purposes as needed diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py index 70c78fb01..7ac98dac8 100644 --- a/llama_stack/core/resolver.py +++ b/llama_stack/core/resolver.py @@ -8,6 +8,7 @@ import inspect from typing import Any from llama_stack.apis.agents import Agents +from llama_stack.apis.batches import Batches from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -75,6 +76,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> Api.agents: Agents, Api.inference: Inference, Api.inspect: Inspect, + Api.batches: Batches, Api.vector_io: VectorIO, Api.vector_dbs: VectorDBs, Api.models: Models, diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index e9d70fc8d..cbef8ef88 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -32,6 +32,7 @@ from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError +from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.cli.utils import add_config_distro_args, get_config_from_args from llama_stack.core.access_control.access_control import AccessDeniedError @@ -128,6 +129,10 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro ] }, ) + elif isinstance(exc, ConflictError): + return HTTPException(status_code=409, detail=str(exc)) + elif isinstance(exc, ResourceNotFoundError): + return HTTPException(status_code=404, detail=str(exc)) elif isinstance(exc, ValueError): return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") elif isinstance(exc, BadRequestError): 
diff --git a/llama_stack/providers/inline/batches/__init__.py b/llama_stack/providers/inline/batches/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/batches/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/batches/reference/__init__.py b/llama_stack/providers/inline/batches/reference/__init__.py new file mode 100644 index 000000000..a8ae92eb2 --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.files import Files +from llama_stack.apis.inference import Inference +from llama_stack.apis.models import Models +from llama_stack.core.datatypes import AccessRule, Api +from llama_stack.providers.utils.kvstore import kvstore_impl + +from .batches import ReferenceBatchesImpl +from .config import ReferenceBatchesImplConfig + +__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"] + + +async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]): + kvstore = await kvstore_impl(config.kvstore) + inference_api: Inference | None = deps.get(Api.inference) + files_api: Files | None = deps.get(Api.files) + models_api: Models | None = deps.get(Api.models) + + if inference_api is None: + raise ValueError("Inference API is required but not provided in dependencies") + if files_api is None: + raise ValueError("Files API is required but not provided in dependencies") + if models_api is None: + raise ValueError("Models API is required but not provided in dependencies") + + impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py new file mode 100644 index 000000000..984ef5a90 --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -0,0 +1,553 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import itertools +import json +import time +import uuid +from io import BytesIO +from typing import Any, Literal + +from openai.types.batch import BatchError, Errors +from pydantic import BaseModel + +from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse +from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError +from llama_stack.apis.files import Files, OpenAIFilePurpose +from llama_stack.apis.inference import Inference +from llama_stack.apis.models import Models +from llama_stack.log import get_logger +from llama_stack.providers.utils.kvstore import KVStore + +from .config import ReferenceBatchesImplConfig + +BATCH_PREFIX = "batch:" + +logger = get_logger(__name__) + + +class AsyncBytesIO: + """ + Async-compatible BytesIO wrapper to allow async file-like operations. 
+ + We use this when uploading files to the Files API, as it expects an + async file-like object. + """ + + def __init__(self, data: bytes): + self._buffer = BytesIO(data) + + async def read(self, n=-1): + return self._buffer.read(n) + + async def seek(self, pos, whence=0): + return self._buffer.seek(pos, whence) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._buffer.close() + + def __getattr__(self, name): + return getattr(self._buffer, name) + + +class BatchRequest(BaseModel): + line_num: int + custom_id: str + method: str + url: str + body: dict[str, Any] + + +class ReferenceBatchesImpl(Batches): + """Reference implementation of the Batches API. + + This implementation processes batch files by making individual requests + to the inference API and generates output files with results. + """ + + def __init__( + self, + config: ReferenceBatchesImplConfig, + inference_api: Inference, + files_api: Files, + models_api: Models, + kvstore: KVStore, + ) -> None: + self.config = config + self.kvstore = kvstore + self.inference_api = inference_api + self.files_api = files_api + self.models_api = models_api + self._processing_tasks: dict[str, asyncio.Task] = {} + self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches) + self._update_batch_lock = asyncio.Lock() + + # this is to allow tests to disable background processing + self.process_batches = True + + async def initialize(self) -> None: + # TODO: start background processing of existing tasks + pass + + async def shutdown(self) -> None: + """Shutdown the batches provider.""" + if self._processing_tasks: + # don't cancel tasks - just let them stop naturally on shutdown + # cancelling would mark batches as "cancelled" in the database + logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks") + + # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions + async def create_batch( + self, + input_file_id: str, + endpoint: str, + completion_window: Literal["24h"], + metadata: dict[str, str] | None = None, + ) -> BatchObject: + """ + Create a new batch for processing multiple API requests. + + Error handling by levels - + 0. Input param handling, results in 40x errors before processing, e.g. + - Wrong completion_window + - Invalid metadata types + - Unknown endpoint + -> no batch created + 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. + - input_file_id missing + - invalid json in file + - missing custom_id, method, url, body + - invalid model + - streaming + -> batch created, validation sends to failed status + 2. Processing errors, result in error_file_id entries, e.g. + - Any error returned from inference endpoint + -> batch created, goes to completed status + """ + + # TODO: set expiration time for garbage collection + + if endpoint not in ["/v1/chat/completions"]: + raise ValueError( + f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", + ) + + if completion_window != "24h": + raise ValueError( + f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. 
Param: completion_window", + ) + + batch_id = f"batch_{uuid.uuid4().hex[:16]}" + current_time = int(time.time()) + + batch = BatchObject( + id=batch_id, + object="batch", + endpoint=endpoint, + input_file_id=input_file_id, + completion_window=completion_window, + status="validating", + created_at=current_time, + metadata=metadata, + ) + + await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) + + if self.process_batches: + task = asyncio.create_task(self._process_batch(batch_id)) + self._processing_tasks[batch_id] = task + + return batch + + async def cancel_batch(self, batch_id: str) -> BatchObject: + """Cancel a batch that is in progress.""" + batch = await self.retrieve_batch(batch_id) + + if batch.status in ["cancelled", "cancelling"]: + return batch + + if batch.status in ["completed", "failed", "expired"]: + raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'") + + await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time())) + + if batch_id in self._processing_tasks: + self._processing_tasks[batch_id].cancel() + # note: task removal and status="cancelled" handled in finally block of _process_batch + + return await self.retrieve_batch(batch_id) + + async def list_batches( + self, + after: str | None = None, + limit: int = 20, + ) -> ListBatchesResponse: + """ + List all batches, eventually only for the current user. + + With no notion of user, we return all batches. + """ + batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff") + + batches = [] + for batch_data in batch_values: + if batch_data: + batches.append(BatchObject.model_validate_json(batch_data)) + + batches.sort(key=lambda b: b.created_at, reverse=True) + + start_idx = 0 + if after: + for i, batch in enumerate(batches): + if batch.id == after: + start_idx = i + 1 + break + + page_batches = batches[start_idx : start_idx + limit] + has_more = (start_idx + limit) < len(batches) + + first_id = page_batches[0].id if page_batches else None + last_id = page_batches[-1].id if page_batches else None + + return ListBatchesResponse( + data=page_batches, + first_id=first_id, + last_id=last_id, + has_more=has_more, + ) + + async def retrieve_batch(self, batch_id: str) -> BatchObject: + """Retrieve information about a specific batch.""" + batch_data = await self.kvstore.get(f"batch:{batch_id}") + if not batch_data: + raise ResourceNotFoundError(batch_id, "Batch", "batches.list()") + + return BatchObject.model_validate_json(batch_data) + + async def _update_batch(self, batch_id: str, **updates) -> None: + """Update batch fields in kvstore.""" + async with self._update_batch_lock: + try: + batch = await self.retrieve_batch(batch_id) + + # batch processing is async. once cancelling, only allow "cancelled" status updates + if batch.status == "cancelling" and updates.get("status") != "cancelled": + logger.info( + f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}" + ) + return + + if "errors" in updates: + updates["errors"] = updates["errors"].model_dump() + + batch_dict = batch.model_dump() + batch_dict.update(updates) + + await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict)) + except Exception as e: + logger.error(f"Failed to update batch {batch_id}: {e}") + + async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]: + """ + Read & validate input, return errors and valid input. 
+ + Validation of + - input_file_id existance + - valid json + - custom_id, method, url, body presence and valid + - no streaming + """ + requests: list[BatchRequest] = [] + errors: list[BatchError] = [] + try: + await self.files_api.openai_retrieve_file(batch.input_file_id) + except Exception: + errors.append( + BatchError( + code="invalid_request", + line=None, + message=f"Cannot find file {batch.input_file_id}.", + param="input_file_id", + ) + ) + return errors, requests + + # TODO(SECURITY): do something about large files + file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id) + file_content = file_content_response.body.decode("utf-8") + for line_num, line in enumerate(file_content.strip().split("\n"), 1): + if line.strip(): # skip empty lines + try: + request = json.loads(line) + + if not isinstance(request, dict): + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message="Each line must be a JSON dictionary object", + ) + ) + continue + + valid = True + + for param, expected_type, type_string in [ + ("custom_id", str, "string"), + ("method", str, "string"), + ("url", str, "string"), + ("body", dict, "JSON dictionary object"), + ]: + if param not in request: + errors.append( + BatchError( + code="missing_required_parameter", + line=line_num, + message=f"Missing required parameter: {param}", + param=param, + ) + ) + valid = False + elif not isinstance(request[param], expected_type): + param_name = "URL" if param == "url" else param.capitalize() + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param_name} must be a {type_string}", + param=param, + ) + ) + valid = False + + if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint: + errors.append( + BatchError( + code="invalid_url", + line=line_num, + message="URL provided for this request does not match the batch endpoint", + param="url", + ) + ) + valid = False + + if (body := request.get("body")) and isinstance(body, dict): + if body.get("stream", False): + errors.append( + BatchError( + code="streaming_unsupported", + line=line_num, + message="Streaming is not supported in batch processing", + param="body.stream", + ) + ) + valid = False + + for param, expected_type, type_string in [ + ("model", str, "a string"), + # messages is specific to /v1/chat/completions + # we could skip validating messages here and let inference fail. however, + # that would be a very expensive way to find out messages is wrong. + ("messages", list, "an array"), # TODO: allow messages to be a string? 
+ ]: + if param not in body: + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param.capitalize()} parameter is required", + param=f"body.{param}", + ) + ) + valid = False + elif not isinstance(body[param], expected_type): + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param.capitalize()} must be {type_string}", + param=f"body.{param}", + ) + ) + valid = False + + if "model" in body and isinstance(body["model"], str): + try: + await self.models_api.get_model(body["model"]) + except Exception: + errors.append( + BatchError( + code="model_not_found", + line=line_num, + message=f"Model '{body['model']}' does not exist or is not supported", + param="body.model", + ) + ) + valid = False + + if valid: + assert isinstance(url, str), "URL must be a string" # for mypy + assert isinstance(body, dict), "Body must be a dictionary" # for mypy + requests.append( + BatchRequest( + line_num=line_num, + url=url, + method=request["method"], + custom_id=request["custom_id"], + body=body, + ), + ) + except json.JSONDecodeError: + errors.append( + BatchError( + code="invalid_json_line", + line=line_num, + message="This line is not parseable as valid JSON.", + ) + ) + + return errors, requests + + async def _process_batch(self, batch_id: str) -> None: + """Background task to process a batch of requests.""" + try: + logger.info(f"Starting batch processing for {batch_id}") + async with self._batch_semaphore: # semaphore to limit concurrency + logger.info(f"Acquired semaphore for batch {batch_id}") + await self._process_batch_impl(batch_id) + except asyncio.CancelledError: + logger.info(f"Batch processing cancelled for {batch_id}") + await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time())) + except Exception as e: + logger.error(f"Batch processing failed for {batch_id}: {e}") + await self._update_batch( + batch_id, + status="failed", + failed_at=int(time.time()), + errors=Errors(data=[BatchError(code="internal_error", message=str(e))]), + ) + finally: + self._processing_tasks.pop(batch_id, None) + + async def _process_batch_impl(self, batch_id: str) -> None: + """Implementation of batch processing logic.""" + errors: list[BatchError] = [] + batch = await self.retrieve_batch(batch_id) + + errors, requests = await self._validate_input(batch) + if errors: + await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)) + logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors") + return + + logger.info(f"Processing {len(requests)} requests for batch {batch_id}") + + total_requests = len(requests) + await self._update_batch( + batch_id, + status="in_progress", + request_counts={"total": total_requests, "completed": 0, "failed": 0}, + ) + + error_results = [] + success_results = [] + completed_count = 0 + failed_count = 0 + + for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch): + # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled + async with asyncio.TaskGroup() as tg: + chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk] + + chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True) + + for result in chunk_results: + if isinstance(result, dict) and result.get("error") is not None: # error response from inference + failed_count += 1 + error_results.append(result) + elif isinstance(result, 
dict) and result.get("response") is not None: # successful inference + completed_count += 1 + success_results.append(result) + else: # unexpected result + failed_count += 1 + errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}")) + + await self._update_batch( + batch_id, + request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count}, + ) + + if errors: + await self._update_batch( + batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors) + ) + return + + try: + output_file_id = await self._create_output_file(batch_id, success_results, "success") + await self._update_batch(batch_id, output_file_id=output_file_id) + + error_file_id = await self._create_output_file(batch_id, error_results, "error") + await self._update_batch(batch_id, error_file_id=error_file_id) + + await self._update_batch(batch_id, status="completed", completed_at=int(time.time())) + + logger.info( + f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed" + ) + except Exception as e: + # note: errors is empty at this point, so we don't lose anything by ignoring it + await self._update_batch( + batch_id, + status="failed", + failed_at=int(time.time()), + errors=Errors(data=[BatchError(code="output_failed", message=str(e))]), + ) + + async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict: + """Process a single request from the batch.""" + request_id = f"batch_req_{batch_id}_{request.line_num}" + + try: + # TODO(SECURITY): review body for security issues + chat_response = await self.inference_api.openai_chat_completion(**request.body) + + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, # TODO: should this be different? + "body": chat_response.model_dump_json(), + }, + } + except Exception as e: + logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") + return { + "id": request_id, + "custom_id": request.custom_id, + "error": {"type": "request_failed", "message": str(e)}, + } + + async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str: + """ + Create an output file with batch results. + + This function filters results based on the specified file_type + and uploads the file to the Files API. + """ + output_lines = [json.dumps(result) for result in results] + + with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer: + file_buffer.filename = f"{batch_id}_{file_type}.jsonl" + uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH) + return uploaded_file.id diff --git a/llama_stack/providers/inline/batches/reference/config.py b/llama_stack/providers/inline/batches/reference/config.py new file mode 100644 index 000000000..d8d06868b --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/config.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig + + +class ReferenceBatchesImplConfig(BaseModel): + """Configuration for the Reference Batches implementation.""" + + kvstore: KVStoreConfig = Field( + description="Configuration for the key-value store backend.", + ) + + max_concurrent_batches: int = Field( + default=1, + description="Maximum number of concurrent batches to process simultaneously.", + ge=1, + ) + + max_concurrent_requests_per_batch: int = Field( + default=10, + description="Maximum number of concurrent requests to process per batch.", + ge=1, + ) + + # TODO: add a max requests per second rate limiter + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> dict: + return { + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="batches.db", + ), + } diff --git a/llama_stack/providers/registry/batches.py b/llama_stack/providers/registry/batches.py new file mode 100644 index 000000000..de7886efb --- /dev/null +++ b/llama_stack/providers/registry/batches.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec + + +def available_providers() -> list[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.batches, + provider_type="inline::reference", + pip_packages=["openai"], + module="llama_stack.providers.inline.batches.reference", + config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig", + api_dependencies=[ + Api.inference, + Api.files, + Api.models, + ], + description="Reference implementation of batches API with KVStore persistence.", + ), + ] diff --git a/scripts/provider_codegen.py b/scripts/provider_codegen.py index 717677c52..060acfa72 100755 --- a/scripts/provider_codegen.py +++ b/scripts/provider_codegen.py @@ -18,6 +18,23 @@ from llama_stack.core.distribution import get_provider_registry REPO_ROOT = Path(__file__).parent.parent +def get_api_docstring(api_name: str) -> str | None: + """Extract docstring from the API protocol class.""" + try: + # Import the API module dynamically + api_module = __import__(f"llama_stack.apis.{api_name}", fromlist=[api_name.title()]) + + # Get the main protocol class (usually capitalized API name) + protocol_class_name = api_name.title() + if hasattr(api_module, protocol_class_name): + protocol_class = getattr(api_module, protocol_class_name) + return protocol_class.__doc__ + except (ImportError, AttributeError): + pass + + return None + + class ChangedPathTracker: """Track a list of paths we may have changed.""" @@ -261,6 +278,11 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N index_content.append(f"# {api_name.title()}\n") index_content.append("## Overview\n") + api_docstring = get_api_docstring(api_name) + if api_docstring: + cleaned_docstring = api_docstring.strip() + index_content.append(f"{cleaned_docstring}\n") + index_content.append( f"This section contains documentation for all available providers for the **{api_name}** API.\n" ) diff --git a/tests/integration/batches/__init__.py b/tests/integration/batches/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/integration/batches/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta 
Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/integration/batches/conftest.py b/tests/integration/batches/conftest.py new file mode 100644 index 000000000..974fe77ab --- /dev/null +++ b/tests/integration/batches/conftest.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Shared pytest fixtures for batch tests.""" + +import json +import time +import warnings +from contextlib import contextmanager +from io import BytesIO + +import pytest + +from llama_stack.apis.files import OpenAIFilePurpose + + +class BatchHelper: + """Helper class for creating and managing batch input files.""" + + def __init__(self, client): + """Initialize with either a batch_client or openai_client.""" + self.client = client + + @contextmanager + def create_file(self, content: str | list[dict], filename_prefix="batch_input"): + """Context manager for creating and cleaning up batch input files. + + Args: + content: Either a list of batch request dictionaries or raw string content + filename_prefix: Prefix for the generated filename (or full filename if content is string) + + Yields: + The uploaded file object + """ + if isinstance(content, str): + # Handle raw string content (e.g., malformed JSONL, empty files) + file_content = content.encode("utf-8") + else: + # Handle list of batch request dictionaries + jsonl_content = "\n".join(json.dumps(req) for req in content) + file_content = jsonl_content.encode("utf-8") + + filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl" + + with BytesIO(file_content) as file_buffer: + file_buffer.name = filename + uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH) + + try: + yield uploaded_file + finally: + try: + self.client.files.delete(uploaded_file.id) + except Exception: + warnings.warn( + f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}", + stacklevel=2, + ) + + def wait_for( + self, + batch_id: str, + max_wait_time: int = 60, + sleep_interval: int | None = None, + expected_statuses: set[str] | None = None, + timeout_action: str = "fail", + ): + """Wait for a batch to reach a terminal status. 
+ + Args: + batch_id: The batch ID to monitor + max_wait_time: Maximum time to wait in seconds (default: 60 seconds) + sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s) + expected_statuses: Set of expected terminal statuses (default: {"completed"}) + timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip) + + Returns: + The final batch object + + Raises: + pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail" + pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status + """ + if sleep_interval is None: + # Default to 1/10th of max_wait_time, with min 1s and max 15s + sleep_interval = max(1, min(15, max_wait_time // 10)) + + if expected_statuses is None: + expected_statuses = {"completed"} + + terminal_statuses = {"completed", "failed", "cancelled", "expired"} + unexpected_statuses = terminal_statuses - expected_statuses + + start_time = time.time() + while time.time() - start_time < max_wait_time: + current_batch = self.client.batches.retrieve(batch_id) + + if current_batch.status in expected_statuses: + return current_batch + elif current_batch.status in unexpected_statuses: + error_msg = f"Batch reached unexpected status: {current_batch.status}" + if timeout_action == "skip": + pytest.skip(error_msg) + else: + pytest.fail(error_msg) + + time.sleep(sleep_interval) + + timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds" + if timeout_action == "skip": + pytest.skip(timeout_msg) + else: + pytest.fail(timeout_msg) + + +@pytest.fixture +def batch_helper(openai_client): + """Fixture that provides a BatchHelper instance for OpenAI client.""" + return BatchHelper(openai_client) diff --git a/tests/integration/batches/test_batches.py b/tests/integration/batches/test_batches.py new file mode 100644 index 000000000..1ef3202d0 --- /dev/null +++ b/tests/integration/batches/test_batches.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Integration tests for the Llama Stack batch processing functionality. + +This module contains comprehensive integration tests for the batch processing API, +using the OpenAI-compatible client interface for consistency. + +Test Categories: + 1. Core Batch Operations: + - test_batch_creation_and_retrieval: Comprehensive batch creation, structure validation, and retrieval + - test_batch_listing: Basic batch listing functionality + - test_batch_immediate_cancellation: Batch cancellation workflow + # TODO: cancel during processing + + 2. End-to-End Processing: + - test_batch_e2e_chat_completions: Full chat completions workflow with output and error validation + +Note: Error conditions and edge cases are primarily tested in test_batches_errors.py +for better organization and separation of concerns. + +CLEANUP WARNING: These tests currently create batches that are not automatically +cleaned up after test completion. This may lead to resource accumulation over +multiple test runs. Only test_batch_immediate_cancellation properly cancels its batch. +The test_batch_e2e_chat_completions test does clean up its output and error files. 
+""" + +import json + + +class TestBatchesIntegration: + """Integration tests for the batches API.""" + + def test_batch_creation_and_retrieval(self, openai_client, batch_helper, text_model_id): + """Test comprehensive batch creation and retrieval scenarios.""" + test_metadata = { + "test_type": "comprehensive", + "purpose": "creation_and_retrieval_test", + "version": "1.0", + "tags": "test,batch", + } + + batch_requests = [ + { + "custom_id": "request-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests, "batch_creation_test") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata=test_metadata, + ) + + assert batch.endpoint == "/v1/chat/completions" + assert batch.input_file_id == uploaded_file.id + assert batch.completion_window == "24h" + assert batch.metadata == test_metadata + + retrieved_batch = openai_client.batches.retrieve(batch.id) + + assert retrieved_batch.id == batch.id + assert retrieved_batch.object == batch.object + assert retrieved_batch.endpoint == batch.endpoint + assert retrieved_batch.input_file_id == batch.input_file_id + assert retrieved_batch.completion_window == batch.completion_window + assert retrieved_batch.metadata == batch.metadata + + def test_batch_listing(self, openai_client, batch_helper, text_model_id): + """ + Test batch listing. + + This test creates multiple batches and verifies that they can be listed. + It also deletes the input files before execution, which means the batches + will appear as failed due to missing input files. This is expected and + a good thing, because it means no inference is performed. 
+ """ + batch_ids = [] + + for i in range(2): + batch_requests = [ + { + "custom_id": f"request-{i}", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": f"Hello {i}"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests, f"batch_input_{i}") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + batch_ids.append(batch.id) + + batch_list = openai_client.batches.list() + + assert isinstance(batch_list.data, list) + + listed_batch_ids = {b.id for b in batch_list.data} + for batch_id in batch_ids: + assert batch_id in listed_batch_ids + + def test_batch_immediate_cancellation(self, openai_client, batch_helper, text_model_id): + """Test immediate batch cancellation.""" + batch_requests = [ + { + "custom_id": "request-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + # hopefully cancel the batch before it completes + cancelling_batch = openai_client.batches.cancel(batch.id) + assert cancelling_batch.status in ["cancelling", "cancelled"] + assert isinstance(cancelling_batch.cancelling_at, int), ( + f"cancelling_at should be int, got {type(cancelling_batch.cancelling_at)}" + ) + + final_batch = batch_helper.wait_for( + batch.id, + max_wait_time=3 * 60, # often takes 10-11 minutes, give it 3 min + expected_statuses={"cancelled"}, + timeout_action="skip", + ) + + assert final_batch.status == "cancelled" + assert isinstance(final_batch.cancelled_at, int), ( + f"cancelled_at should be int, got {type(final_batch.cancelled_at)}" + ) + + def test_batch_e2e_chat_completions(self, openai_client, batch_helper, text_model_id): + """Test end-to-end batch processing for chat completions with both successful and failed operations.""" + batch_requests = [ + { + "custom_id": "success-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Say hello"}], + "max_tokens": 20, + }, + }, + { + "custom_id": "error-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "This should fail"}], + "max_tokens": -1, # Invalid negative max_tokens will cause inference error + }, + }, + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={"test": "e2e_success_and_errors_test"}, + ) + + final_batch = batch_helper.wait_for( + batch.id, + max_wait_time=3 * 60, # often takes 2-3 minutes + expected_statuses={"completed"}, + timeout_action="skip", + ) + + # Expecting a completed batch with both successful and failed requests + # Batch(id='batch_xxx', + # completion_window='24h', + # created_at=..., + # endpoint='/v1/chat/completions', + # input_file_id='file-xxx', + # object='batch', + # status='completed', + # output_file_id='file-xxx', + # error_file_id='file-xxx', + # request_counts=BatchRequestCounts(completed=1, failed=1, total=2)) + + assert 
final_batch.status == "completed" + assert final_batch.request_counts is not None + assert final_batch.request_counts.total == 2 + assert final_batch.request_counts.completed == 1 + assert final_batch.request_counts.failed == 1 + + assert final_batch.output_file_id is not None, "Output file should exist for successful requests" + + output_content = openai_client.files.content(final_batch.output_file_id) + if isinstance(output_content, str): + output_text = output_content + else: + output_text = output_content.content.decode("utf-8") + + output_lines = output_text.strip().split("\n") + + for line in output_lines: + result = json.loads(line) + + assert "id" in result + assert "custom_id" in result + assert result["custom_id"] == "success-1" + + assert "response" in result + + assert result["response"]["status_code"] == 200 + assert "body" in result["response"] + assert "choices" in result["response"]["body"] + + assert final_batch.error_file_id is not None, "Error file should exist for failed requests" + + error_content = openai_client.files.content(final_batch.error_file_id) + if isinstance(error_content, str): + error_text = error_content + else: + error_text = error_content.content.decode("utf-8") + + error_lines = error_text.strip().split("\n") + + for line in error_lines: + result = json.loads(line) + + assert "id" in result + assert "custom_id" in result + assert result["custom_id"] == "error-1" + assert "error" in result + error = result["error"] + assert error is not None + assert "code" in error or "message" in error, "Error should have code or message" + + deleted_output_file = openai_client.files.delete(final_batch.output_file_id) + assert deleted_output_file.deleted, f"Output file {final_batch.output_file_id} was not deleted successfully" + + deleted_error_file = openai_client.files.delete(final_batch.error_file_id) + assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully" diff --git a/tests/integration/batches/test_batches_errors.py b/tests/integration/batches/test_batches_errors.py new file mode 100644 index 000000000..bc94a182e --- /dev/null +++ b/tests/integration/batches/test_batches_errors.py @@ -0,0 +1,693 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Error handling and edge case tests for the Llama Stack batch processing functionality. + +This module focuses exclusively on testing error conditions, validation failures, +and edge cases for batch operations to ensure robust error handling and graceful +degradation. + +Test Categories: + 1. File and Input Validation: + - test_batch_nonexistent_file_id: Handling invalid file IDs + - test_batch_malformed_jsonl: Processing malformed JSONL input files + - test_file_malformed_batch_file: Handling malformed files at upload time + - test_batch_missing_required_fields: Validation of required request fields + + 2. API Endpoint and Model Validation: + - test_batch_invalid_endpoint: Invalid endpoint handling during creation + - test_batch_error_handling_invalid_model: Error handling with nonexistent models + - test_batch_endpoint_mismatch: Validation of endpoint/URL consistency + + 3. 
Batch Lifecycle Error Handling: + - test_batch_retrieve_nonexistent: Retrieving non-existent batches + - test_batch_cancel_nonexistent: Cancelling non-existent batches + - test_batch_cancel_completed: Attempting to cancel completed batches + + 4. Parameter and Configuration Validation: + - test_batch_invalid_completion_window: Invalid completion window values + - test_batch_invalid_metadata_types: Invalid metadata type validation + - test_batch_missing_required_body_fields: Validation of required fields in request body + + 5. Feature Restriction and Compatibility: + - test_batch_streaming_not_supported: Streaming request rejection + - test_batch_mixed_streaming_requests: Mixed streaming/non-streaming validation + +Note: Core functionality and OpenAI compatibility tests are located in +test_batches_integration.py for better organization and separation of concerns. + +CLEANUP WARNING: These tests create batches to test error conditions but do not +automatically clean them up after test completion. While most error tests create +batches that fail quickly, some may create valid batches that consume resources. +""" + +import pytest +from openai import BadRequestError, ConflictError, NotFoundError + + +class TestBatchesErrorHandling: + """Error handling and edge case tests for the batches API using OpenAI client.""" + + def test_batch_nonexistent_file_id(self, openai_client, batch_helper): + """Test batch creation with nonexistent input file ID.""" + + batch = openai_client.batches.create( + input_file_id="file-nonexistent-xyz", + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError( + # code='invalid_request', + # line=None, + # message='Cannot find file ..., or organization ... does not have access to it.', + # param='file_id') + # ], object='list'), + # failed_at=1754566971, + # ...) + + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 1 + error = final_batch.errors.data[0] + assert error.code == "invalid_request" + assert "cannot find file" in error.message.lower() + + def test_batch_invalid_endpoint(self, openai_client, batch_helper, text_model_id): + """Test batch creation with invalid endpoint.""" + batch_requests = [ + { + "custom_id": "invalid-endpoint", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + with pytest.raises(BadRequestError) as exc_info: + openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/invalid/endpoint", + completion_window="24h", + ) + + # Expected - + # Error code: 400 - { + # 'error': { + # 'message': "Invalid value: '/v1/invalid/endpoint'. 
Supported values are: '/v1/chat/completions', '/v1/completions', '/v1/embeddings', and '/v1/responses'.", + # 'type': 'invalid_request_error', + # 'param': 'endpoint', + # 'code': 'invalid_value' + # } + # } + + error_msg = str(exc_info.value).lower() + assert exc_info.value.status_code == 400 + assert "invalid value" in error_msg + assert "/v1/invalid/endpoint" in error_msg + assert "supported values" in error_msg + assert "endpoint" in error_msg + assert "invalid_value" in error_msg + + def test_batch_malformed_jsonl(self, openai_client, batch_helper): + """ + Test batch with malformed JSONL input. + + The /v1/files endpoint requires valid JSONL format, so we provide a well formed line + before a malformed line to ensure we get to the /v1/batches validation stage. + """ + with batch_helper.create_file( + """{"custom_id": "valid", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test"}} +{invalid json here""", + "malformed_batch_input.jsonl", + ) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # ..., + # BatchError(code='invalid_json_line', + # line=2, + # message='This line is not parseable as valid JSON.', + # param=None) + # ], object='list'), + # ...) + + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) > 0 + error = final_batch.errors.data[-1] # get last error because first may be about the "test" model + assert error.code == "invalid_json_line" + assert error.line == 2 + assert "not" in error.message.lower() + assert "valid json" in error.message.lower() + + @pytest.mark.xfail(reason="Not all file providers validate content") + @pytest.mark.parametrize("batch_requests", ["", "{malformed json"], ids=["empty", "malformed"]) + def test_file_malformed_batch_file(self, openai_client, batch_helper, batch_requests): + """Test file upload with malformed content.""" + + with pytest.raises(BadRequestError) as exc_info: + with batch_helper.create_file(batch_requests, "malformed_batch_input_file.jsonl"): + # /v1/files rejects the file, we don't get to batch creation + pass + + error_msg = str(exc_info.value).lower() + assert exc_info.value.status_code == 400 + assert "invalid file format" in error_msg + assert "jsonl" in error_msg + + def test_batch_retrieve_nonexistent(self, openai_client): + """Test retrieving nonexistent batch.""" + with pytest.raises(NotFoundError) as exc_info: + openai_client.batches.retrieve("batch-nonexistent-xyz") + + error_msg = str(exc_info.value).lower() + assert exc_info.value.status_code == 404 + assert "no batch found" in error_msg or "not found" in error_msg + + def test_batch_cancel_nonexistent(self, openai_client): + """Test cancelling nonexistent batch.""" + with pytest.raises(NotFoundError) as exc_info: + openai_client.batches.cancel("batch-nonexistent-xyz") + + error_msg = str(exc_info.value).lower() + assert exc_info.value.status_code == 404 + assert "no batch found" in error_msg or "not found" in error_msg + + def test_batch_cancel_completed(self, openai_client, batch_helper, text_model_id): + """Test cancelling already completed batch.""" + batch_requests = [ + { + "custom_id": "cancel-completed", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": 
[{"role": "user", "content": "Quick test"}], + "max_tokens": 5, + }, + } + ] + + with batch_helper.create_file(batch_requests, "cancel_test_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for( + batch.id, + max_wait_time=3 * 60, # often take 10-11 min, give it 3 min + expected_statuses={"completed"}, + timeout_action="skip", + ) + + deleted_file = openai_client.files.delete(final_batch.output_file_id) + assert deleted_file.deleted, f"File {final_batch.output_file_id} was not deleted successfully" + + with pytest.raises(ConflictError) as exc_info: + openai_client.batches.cancel(batch.id) + + # Expecting - + # Error code: 409 - { + # 'error': { + # 'message': "Cannot cancel a batch with status 'completed'.", + # 'type': 'invalid_request_error', + # 'param': None, + # 'code': None + # } + # } + # + # NOTE: Same for "failed", cancelling "cancelled" batches is allowed + + error_msg = str(exc_info.value).lower() + assert exc_info.value.status_code == 409 + assert "cannot cancel" in error_msg + + def test_batch_missing_required_fields(self, openai_client, batch_helper, text_model_id): + """Test batch with requests missing required fields.""" + batch_requests = [ + { + # Missing custom_id + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "No custom_id"}], + "max_tokens": 10, + }, + }, + { + "custom_id": "no-method", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "No method"}], + "max_tokens": 10, + }, + }, + { + "custom_id": "no-url", + "method": "POST", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "No URL"}], + "max_tokens": 10, + }, + }, + { + "custom_id": "no-body", + "method": "POST", + "url": "/v1/chat/completions", + }, + ] + + with batch_helper.create_file(batch_requests, "missing_fields_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors( + # data=[ + # BatchError( + # code='missing_required_parameter', + # line=1, + # message="Missing required parameter: 'custom_id'.", + # param='custom_id' + # ), + # BatchError( + # code='missing_required_parameter', + # line=2, + # message="Missing required parameter: 'method'.", + # param='method' + # ), + # BatchError( + # code='missing_required_parameter', + # line=3, + # message="Missing required parameter: 'url'.", + # param='url' + # ), + # BatchError( + # code='missing_required_parameter', + # line=4, + # message="Missing required parameter: 'body'.", + # param='body' + # ) + # ], object='list'), + # failed_at=1754566945, + # ...) 
+ # ) + + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 4 + no_custom_id_error = final_batch.errors.data[0] + assert no_custom_id_error.code == "missing_required_parameter" + assert no_custom_id_error.line == 1 + assert "missing" in no_custom_id_error.message.lower() + assert "custom_id" in no_custom_id_error.message.lower() + no_method_error = final_batch.errors.data[1] + assert no_method_error.code == "missing_required_parameter" + assert no_method_error.line == 2 + assert "missing" in no_method_error.message.lower() + assert "method" in no_method_error.message.lower() + no_url_error = final_batch.errors.data[2] + assert no_url_error.code == "missing_required_parameter" + assert no_url_error.line == 3 + assert "missing" in no_url_error.message.lower() + assert "url" in no_url_error.message.lower() + no_body_error = final_batch.errors.data[3] + assert no_body_error.code == "missing_required_parameter" + assert no_body_error.line == 4 + assert "missing" in no_body_error.message.lower() + assert "body" in no_body_error.message.lower() + + def test_batch_invalid_completion_window(self, openai_client, batch_helper, text_model_id): + """Test batch creation with invalid completion window.""" + batch_requests = [ + { + "custom_id": "invalid-completion-window", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + for window in ["1h", "48h", "invalid", ""]: + with pytest.raises(BadRequestError) as exc_info: + openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window=window, + ) + assert exc_info.value.status_code == 400 + error_msg = str(exc_info.value).lower() + assert "error" in error_msg + assert "completion_window" in error_msg + + def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id): + """Test that streaming responses are not supported in batches.""" + batch_requests = [ + { + "custom_id": "streaming-test", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + "stream": True, # Not supported + }, + } + ] + + with batch_helper.create_file(batch_requests, "streaming_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError(code='streaming_unsupported', + # line=1, + # message='Chat Completions: Streaming is not supported in the Batch API.', + # param='body.stream') + # ], object='list'), + # failed_at=1754566965, + # ...) 
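+        # Streaming is rejected during validation, so the failure surfaces as a batch-level
+        # error in batch.errors; per-request runtime failures are instead written to
+        # error_file_id (see test_batch_e2e_chat_completions).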
+ + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 1 + error = final_batch.errors.data[0] + assert error.code == "streaming_unsupported" + assert error.line == 1 + assert "streaming" in error.message.lower() + assert "not supported" in error.message.lower() + assert error.param == "body.stream" + assert final_batch.failed_at is not None + + def test_batch_mixed_streaming_requests(self, openai_client, batch_helper, text_model_id): + """ + Test batch with mixed streaming and non-streaming requests. + + This is distinct from test_batch_streaming_not_supported, which tests a single + streaming request, to ensure an otherwise valid batch fails when a single + streaming request is included. + """ + batch_requests = [ + { + "custom_id": "valid-non-streaming-request", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello without streaming"}], + "max_tokens": 10, + }, + }, + { + "custom_id": "streaming-request", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello with streaming"}], + "max_tokens": 10, + "stream": True, # Not supported + }, + }, + ] + + with batch_helper.create_file(batch_requests, "mixed_streaming_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError( + # code='streaming_unsupported', + # line=2, + # message='Chat Completions: Streaming is not supported in the Batch API.', + # param='body.stream') + # ], object='list'), + # failed_at=1754574442, + # ...) + + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 1 + error = final_batch.errors.data[0] + assert error.code == "streaming_unsupported" + assert error.line == 2 + assert "streaming" in error.message.lower() + assert "not supported" in error.message.lower() + assert error.param == "body.stream" + assert final_batch.failed_at is not None + + def test_batch_endpoint_mismatch(self, openai_client, batch_helper, text_model_id): + """Test batch creation with mismatched endpoint and request URL.""" + batch_requests = [ + { + "custom_id": "endpoint-mismatch", + "method": "POST", + "url": "/v1/embeddings", # Different from batch endpoint + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + }, + } + ] + + with batch_helper.create_file(batch_requests, "endpoint_mismatch_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", # Different from request URL + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError( + # code='invalid_url', + # line=1, + # message='The URL provided for this request does not match the batch endpoint.', + # param='url') + # ], object='list'), + # failed_at=1754566972, + # ...) 
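+        # Each request's url in the input file must match the batch's endpoint exactly;
+        # a single mismatched line fails the whole batch at validation time.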
+ + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 1 + error = final_batch.errors.data[0] + assert error.line == 1 + assert error.code == "invalid_url" + assert "does not match" in error.message.lower() + assert "endpoint" in error.message.lower() + assert final_batch.failed_at is not None + + def test_batch_error_handling_invalid_model(self, openai_client, batch_helper): + """Test batch error handling with invalid model.""" + batch_requests = [ + { + "custom_id": "invalid-model", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "nonexistent-model-xyz", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError(code='model_not_found', + # line=1, + # message="The provided model 'nonexistent-model-xyz' is not supported by the Batch API.", + # param='body.model') + # ], object='list'), + # failed_at=1754566978, + # ...) + + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 1 + error = final_batch.errors.data[0] + assert error.line == 1 + assert error.code == "model_not_found" + assert "not supported" in error.message.lower() + assert error.param == "body.model" + assert final_batch.failed_at is not None + + def test_batch_missing_required_body_fields(self, openai_client, batch_helper, text_model_id): + """Test batch with requests missing required fields in body (model and messages).""" + batch_requests = [ + { + "custom_id": "missing-model", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + # Missing model field + "messages": [{"role": "user", "content": "Hello without model"}], + "max_tokens": 10, + }, + }, + { + "custom_id": "missing-messages", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + # Missing messages field + "max_tokens": 10, + }, + }, + ] + + with batch_helper.create_file(batch_requests, "missing_body_fields_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError( + # code='invalid_request', + # line=1, + # message='Model parameter is required.', + # param='body.model'), + # BatchError( + # code='invalid_request', + # line=2, + # message='Messages parameter is required.', + # param='body.messages') + # ], object='list'), + # ...) 
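+        # Note the asymmetry with test_batch_missing_required_fields: missing top-level fields
+        # are reported as 'missing_required_parameter', while missing body fields are reported
+        # as 'invalid_request' with a 'body.<field>' param.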
+ + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 2 + + model_error = final_batch.errors.data[0] + assert model_error.line == 1 + assert "model" in model_error.message.lower() + assert model_error.param == "body.model" + + messages_error = final_batch.errors.data[1] + assert messages_error.line == 2 + assert "messages" in messages_error.message.lower() + assert messages_error.param == "body.messages" + + assert final_batch.failed_at is not None + + def test_batch_invalid_metadata_types(self, openai_client, batch_helper, text_model_id): + """Test batch creation with invalid metadata types (like lists).""" + batch_requests = [ + { + "custom_id": "invalid-metadata-type", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + with pytest.raises(Exception) as exc_info: + openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={ + "tags": ["tag1", "tag2"], # Invalid type, should be a string + }, + ) + + # Expecting - + # Error code: 400 - {'error': + # {'message': "Invalid type for 'metadata.tags': expected a string, + # but got an array instead.", + # 'type': 'invalid_request_error', 'param': 'metadata.tags', + # 'code': 'invalid_type'}} + + error_msg = str(exc_info.value).lower() + assert "400" in error_msg + assert "tags" in error_msg + assert "string" in error_msg diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py new file mode 100644 index 000000000..9fe0cc710 --- /dev/null +++ b/tests/unit/providers/batches/test_reference.py @@ -0,0 +1,753 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Test suite for the reference implementation of the Batches API. 
+ +The tests are categorized and outlined below, keep this updated: + +- Batch creation with various parameters and validation: + * test_create_and_retrieve_batch_success (positive) + * test_create_batch_without_metadata (positive) + * test_create_batch_completion_window (negative) + * test_create_batch_invalid_endpoints (negative) + * test_create_batch_invalid_metadata (negative) + +- Batch retrieval and error handling for non-existent batches: + * test_retrieve_batch_not_found (negative) + +- Batch cancellation with proper status transitions: + * test_cancel_batch_success (positive) + * test_cancel_batch_invalid_statuses (negative) + * test_cancel_batch_not_found (negative) + +- Batch listing with pagination and filtering: + * test_list_batches_empty (positive) + * test_list_batches_single_batch (positive) + * test_list_batches_multiple_batches (positive) + * test_list_batches_with_limit (positive) + * test_list_batches_with_pagination (positive) + * test_list_batches_invalid_after (negative) + +- Data persistence in the underlying key-value store: + * test_kvstore_persistence (positive) + +- Batch processing concurrency control: + * test_max_concurrent_batches (positive) + +- Input validation testing (direct _validate_input method tests): + * test_validate_input_file_not_found (negative) + * test_validate_input_file_exists_empty_content (positive) + * test_validate_input_file_mixed_valid_invalid_json (mixed) + * test_validate_input_invalid_model (negative) + * test_validate_input_url_mismatch (negative) + * test_validate_input_multiple_errors_per_request (negative) + * test_validate_input_invalid_request_format (negative) + * test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation) + * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation) + +The tests use temporary SQLite databases for isolation and mock external +dependencies like inference, files, and models APIs. 
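+
+Each test runs against a fresh provider instance created by the `provider` fixture, which
+wires mocked inference, files, and models APIs around a temporary SQLite-backed kvstore and
+disables background batch processing (individual tests re-enable it when needed).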
+""" + +import json +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from llama_stack.apis.batches import BatchObject +from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError +from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl +from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + + +class TestReferenceBatchesImpl: + """Test the reference implementation of the Batches API.""" + + @pytest.fixture + async def provider(self): + """Create a test provider instance with temporary database.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test_batches.db" + kvstore_config = SqliteKVStoreConfig(db_path=str(db_path)) + config = ReferenceBatchesImplConfig(kvstore=kvstore_config) + + # Create kvstore and mock APIs + from unittest.mock import AsyncMock + + from llama_stack.providers.utils.kvstore import kvstore_impl + + kvstore = await kvstore_impl(config.kvstore) + mock_inference = AsyncMock() + mock_files = AsyncMock() + mock_models = AsyncMock() + + provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore) + await provider.initialize() + + # unit tests should not require background processing + provider.process_batches = False + + yield provider + + await provider.shutdown() + + @pytest.fixture + def sample_batch_data(self): + """Sample batch data for testing.""" + return { + "input_file_id": "file_abc123", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + "metadata": {"test": "true", "priority": "high"}, + } + + def _validate_batch_type(self, batch, expected_metadata=None): + """ + Helper function to validate batch object structure and field types. + + Note: This validates the direct BatchObject from the provider, not the + client library response which has a different structure. + + Args: + batch: The BatchObject instance to validate. + expected_metadata: Optional expected metadata dictionary to validate against. 
+ """ + assert isinstance(batch.id, str) + assert isinstance(batch.completion_window, str) + assert isinstance(batch.created_at, int) + assert isinstance(batch.endpoint, str) + assert isinstance(batch.input_file_id, str) + assert batch.object == "batch" + assert batch.status in [ + "validating", + "failed", + "in_progress", + "finalizing", + "completed", + "expired", + "cancelling", + "cancelled", + ] + + if expected_metadata is not None: + assert batch.metadata == expected_metadata + + timestamp_fields = [ + "cancelled_at", + "cancelling_at", + "completed_at", + "expired_at", + "expires_at", + "failed_at", + "finalizing_at", + "in_progress_at", + ] + for field in timestamp_fields: + field_value = getattr(batch, field, None) + if field_value is not None: + assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}" + + file_id_fields = ["error_file_id", "output_file_id"] + for field in file_id_fields: + field_value = getattr(batch, field, None) + if field_value is not None: + assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}" + + if hasattr(batch, "request_counts") and batch.request_counts is not None: + assert isinstance(batch.request_counts.completed, int), ( + f"request_counts.completed should be int, got {type(batch.request_counts.completed)}" + ) + assert isinstance(batch.request_counts.failed, int), ( + f"request_counts.failed should be int, got {type(batch.request_counts.failed)}" + ) + assert isinstance(batch.request_counts.total, int), ( + f"request_counts.total should be int, got {type(batch.request_counts.total)}" + ) + + if hasattr(batch, "errors") and batch.errors is not None: + assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}" + + if hasattr(batch.errors, "data") and batch.errors.data is not None: + assert isinstance(batch.errors.data, list), ( + f"errors.data should be list or None, got {type(batch.errors.data)}" + ) + + for i, error_item in enumerate(batch.errors.data): + assert isinstance(error_item, dict), ( + f"errors.data[{i}] should be object or dict, got {type(error_item)}" + ) + + if hasattr(error_item, "code") and error_item.code is not None: + assert isinstance(error_item.code, str), ( + f"errors.data[{i}].code should be str or None, got {type(error_item.code)}" + ) + + if hasattr(error_item, "line") and error_item.line is not None: + assert isinstance(error_item.line, int), ( + f"errors.data[{i}].line should be int or None, got {type(error_item.line)}" + ) + + if hasattr(error_item, "message") and error_item.message is not None: + assert isinstance(error_item.message, str), ( + f"errors.data[{i}].message should be str or None, got {type(error_item.message)}" + ) + + if hasattr(error_item, "param") and error_item.param is not None: + assert isinstance(error_item.param, str), ( + f"errors.data[{i}].param should be str or None, got {type(error_item.param)}" + ) + + if hasattr(batch.errors, "object") and batch.errors.object is not None: + assert isinstance(batch.errors.object, str), ( + f"errors.object should be str or None, got {type(batch.errors.object)}" + ) + assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}" + + async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data): + """Test successful batch creation and retrieval.""" + created_batch = await provider.create_batch(**sample_batch_data) + + self._validate_batch_type(created_batch, 
expected_metadata=sample_batch_data["metadata"]) + + assert created_batch.id.startswith("batch_") + assert len(created_batch.id) > 13 + assert created_batch.object == "batch" + assert created_batch.endpoint == sample_batch_data["endpoint"] + assert created_batch.input_file_id == sample_batch_data["input_file_id"] + assert created_batch.completion_window == sample_batch_data["completion_window"] + assert created_batch.status == "validating" + assert created_batch.metadata == sample_batch_data["metadata"] + assert isinstance(created_batch.created_at, int) + assert created_batch.created_at > 0 + + retrieved_batch = await provider.retrieve_batch(created_batch.id) + + self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"]) + + assert retrieved_batch.id == created_batch.id + assert retrieved_batch.input_file_id == created_batch.input_file_id + assert retrieved_batch.endpoint == created_batch.endpoint + assert retrieved_batch.status == created_batch.status + assert retrieved_batch.metadata == created_batch.metadata + + async def test_create_batch_without_metadata(self, provider): + """Test batch creation without optional metadata.""" + batch = await provider.create_batch( + input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h" + ) + + assert batch.metadata is None + + async def test_create_batch_completion_window(self, provider): + """Test batch creation with invalid completion window.""" + with pytest.raises(ValueError, match="Invalid completion_window"): + await provider.create_batch( + input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now" + ) + + @pytest.mark.parametrize( + "endpoint", + [ + "/v1/embeddings", + "/v1/completions", + "/v1/invalid/endpoint", + "", + ], + ) + async def test_create_batch_invalid_endpoints(self, provider, endpoint): + """Test batch creation with various invalid endpoints.""" + with pytest.raises(ValueError, match="Invalid endpoint"): + await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h") + + async def test_create_batch_invalid_metadata(self, provider): + """Test that batch creation fails with invalid metadata.""" + with pytest.raises(ValueError, match="should be a valid string"): + await provider.create_batch( + input_file_id="file_123", + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={123: "invalid_key"}, # Non-string key + ) + + with pytest.raises(ValueError, match="should be a valid string"): + await provider.create_batch( + input_file_id="file_123", + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={"valid_key": 456}, # Non-string value + ) + + async def test_retrieve_batch_not_found(self, provider): + """Test error when retrieving non-existent batch.""" + with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): + await provider.retrieve_batch("nonexistent_batch") + + async def test_cancel_batch_success(self, provider, sample_batch_data): + """Test successful batch cancellation.""" + created_batch = await provider.create_batch(**sample_batch_data) + assert created_batch.status == "validating" + + cancelled_batch = await provider.cancel_batch(created_batch.id) + + assert cancelled_batch.id == created_batch.id + assert cancelled_batch.status in ["cancelling", "cancelled"] + assert isinstance(cancelled_batch.cancelling_at, int) + assert cancelled_batch.cancelling_at >= created_batch.created_at + + @pytest.mark.parametrize("status", ["failed", 
"expired", "completed"]) + async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status): + """Test error when cancelling batch in final states.""" + provider.process_batches = False + created_batch = await provider.create_batch(**sample_batch_data) + + # directly update status in kvstore + await provider._update_batch(created_batch.id, status=status) + + with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"): + await provider.cancel_batch(created_batch.id) + + async def test_cancel_batch_not_found(self, provider): + """Test error when cancelling non-existent batch.""" + with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): + await provider.cancel_batch("nonexistent_batch") + + async def test_list_batches_empty(self, provider): + """Test listing batches when none exist.""" + response = await provider.list_batches() + + assert response.object == "list" + assert response.data == [] + assert response.first_id is None + assert response.last_id is None + assert response.has_more is False + + async def test_list_batches_single_batch(self, provider, sample_batch_data): + """Test listing batches with single batch.""" + created_batch = await provider.create_batch(**sample_batch_data) + + response = await provider.list_batches() + + assert len(response.data) == 1 + self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"]) + assert response.data[0].id == created_batch.id + assert response.first_id == created_batch.id + assert response.last_id == created_batch.id + assert response.has_more is False + + async def test_list_batches_multiple_batches(self, provider): + """Test listing multiple batches.""" + batches = [ + await provider.create_batch( + input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + ) + for i in range(3) + ] + + response = await provider.list_batches() + + assert len(response.data) == 3 + + batch_ids = {batch.id for batch in response.data} + expected_ids = {batch.id for batch in batches} + assert batch_ids == expected_ids + assert response.has_more is False + + assert response.first_id in expected_ids + assert response.last_id in expected_ids + + async def test_list_batches_with_limit(self, provider): + """Test listing batches with limit parameter.""" + batches = [ + await provider.create_batch( + input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + ) + for i in range(3) + ] + + response = await provider.list_batches(limit=2) + + assert len(response.data) == 2 + assert response.has_more is True + assert response.first_id == response.data[0].id + assert response.last_id == response.data[1].id + batch_ids = {batch.id for batch in response.data} + expected_ids = {batch.id for batch in batches} + assert batch_ids.issubset(expected_ids) + + async def test_list_batches_with_pagination(self, provider): + """Test listing batches with pagination using 'after' parameter.""" + for i in range(3): + await provider.create_batch( + input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + ) + + # Get first page + first_page = await provider.list_batches(limit=1) + assert len(first_page.data) == 1 + assert first_page.has_more is True + + # Get second page using 'after' + second_page = await provider.list_batches(limit=1, after=first_page.data[0].id) + assert len(second_page.data) == 1 + assert second_page.data[0].id != first_page.data[0].id + + # Verify we got 
the next batch in order + all_batches = await provider.list_batches() + expected_second_batch_id = all_batches.data[1].id + assert second_page.data[0].id == expected_second_batch_id + + async def test_list_batches_invalid_after(self, provider, sample_batch_data): + """Test listing batches with invalid 'after' parameter.""" + await provider.create_batch(**sample_batch_data) + + response = await provider.list_batches(after="nonexistent_batch") + + # Should return all batches (no filtering when 'after' batch not found) + assert len(response.data) == 1 + + async def test_kvstore_persistence(self, provider, sample_batch_data): + """Test that batches are properly persisted in kvstore.""" + batch = await provider.create_batch(**sample_batch_data) + + stored_data = await provider.kvstore.get(f"batch:{batch.id}") + assert stored_data is not None + + stored_batch_dict = json.loads(stored_data) + assert stored_batch_dict["id"] == batch.id + assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"] + + async def test_validate_input_file_not_found(self, provider): + """Test _validate_input when input file does not exist.""" + provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found")) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id="nonexistent_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + assert errors[0].code == "invalid_request" + assert errors[0].message == "Cannot find file nonexistent_file." + assert errors[0].param == "input_file_id" + assert errors[0].line is None + + async def test_validate_input_file_exists_empty_content(self, provider): + """Test _validate_input when file exists but is empty.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + mock_response.body = b"" + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id="empty_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 0 + assert len(requests) == 0 + + async def test_validate_input_file_mixed_valid_invalid_json(self, provider): + """Test _validate_input when file contains valid and invalid JSON lines.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + # Line 1: valid JSON with proper body args, Line 2: invalid JSON + mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json' + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id="mixed_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + # Should have 1 JSON parsing error from line 2, and 1 valid request from line 1 + assert len(errors) == 1 + assert len(requests) == 1 + + assert errors[0].code == "invalid_json_line" + assert errors[0].line == 2 + assert errors[0].message == "This line is 
not parseable as valid JSON." + + assert requests[0].custom_id == "req-1" + assert requests[0].method == "POST" + assert requests[0].url == "/v1/chat/completions" + assert requests[0].body["model"] == "test-model" + assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}] + + async def test_validate_input_invalid_model(self, provider): + """Test _validate_input when file contains request with non-existent model.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}' + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found")) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id="invalid_model_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == "model_not_found" + assert errors[0].line == 1 + assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported" + assert errors[0].param == "body.model" + + @pytest.mark.parametrize( + "param_name,param_path,error_code,error_message", + [ + ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"), + ("method", "method", "missing_required_parameter", "Missing required parameter: method"), + ("url", "url", "missing_required_parameter", "Missing required parameter: url"), + ("body", "body", "missing_required_parameter", "Missing required parameter: body"), + ("model", "body.model", "invalid_request", "Model parameter is required"), + ("messages", "body.messages", "invalid_request", "Messages parameter is required"), + ], + ) + async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message): + """Test _validate_input when file contains request with missing required parameters.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + + base_request = { + "custom_id": "req-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}, + } + + # Remove the specific parameter being tested + if "." 
in param_path: + top_level, nested_param = param_path.split(".", 1) + del base_request[top_level][nested_param] + else: + del base_request[param_name] + + mock_response.body = json.dumps(base_request).encode() + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id=f"missing_{param_name}_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == error_code + assert errors[0].line == 1 + assert errors[0].message == error_message + assert errors[0].param == param_path + + async def test_validate_input_url_mismatch(self, provider): + """Test _validate_input when file contains request with URL that doesn't match batch endpoint.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}' + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", # This doesn't match the URL in the request + input_file_id="url_mismatch_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == "invalid_url" + assert errors[0].line == 1 + assert errors[0].message == "URL provided for this request does not match the batch endpoint" + assert errors[0].param == "url" + + async def test_validate_input_multiple_errors_per_request(self, provider): + """Test _validate_input when a single request has multiple validation errors.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + # Request missing custom_id, has invalid URL, and missing model in body + mock_response.body = ( + b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}' + ) + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request + input_file_id="multiple_errors_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) >= 2 # At least missing custom_id and URL mismatch + assert len(requests) == 0 + + for error in errors: + assert error.line == 1 + + error_codes = {error.code for error in errors} + assert "missing_required_parameter" in error_codes # missing custom_id + assert "invalid_url" in error_codes # URL mismatch + + async def test_validate_input_invalid_request_format(self, provider): + """Test _validate_input when file contains non-object JSON (array, string, number).""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + mock_response.body = b'["not", "a", "request", "object"]' + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + 
endpoint="/v1/chat/completions", + input_file_id="invalid_format_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == "invalid_request" + assert errors[0].line == 1 + assert errors[0].message == "Each line must be a JSON dictionary object" + + @pytest.mark.parametrize( + "param_name,param_path,invalid_value,error_message", + [ + ("custom_id", "custom_id", 12345, "Custom_id must be a string"), + ("url", "url", 123, "URL must be a string"), + ("method", "method", ["POST"], "Method must be a string"), + ("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"), + ("model", "body.model", 123, "Model must be a string"), + ("messages", "body.messages", "invalid messages format", "Messages must be an array"), + ], + ) + async def test_validate_input_invalid_parameter_types( + self, provider, param_name, param_path, invalid_value, error_message + ): + """Test _validate_input when file contains request with parameters that have invalid types.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + + base_request = { + "custom_id": "req-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}, + } + + # Override the specific parameter with invalid value + if "." in param_path: + top_level, nested_param = param_path.split(".", 1) + base_request[top_level][nested_param] = invalid_value + else: + base_request[param_name] = invalid_value + + mock_response.body = json.dumps(base_request).encode() + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id=f"invalid_{param_name}_type_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == "invalid_request" + assert errors[0].line == 1 + assert errors[0].message == error_message + assert errors[0].param == param_path + + async def test_max_concurrent_batches(self, provider): + """Test max_concurrent_batches configuration and concurrency control.""" + import asyncio + + provider._batch_semaphore = asyncio.Semaphore(2) + + provider.process_batches = True # enable because we're testing background processing + + active_batches = 0 + + async def add_and_wait(batch_id: str): + nonlocal active_batches + active_batches += 1 + await asyncio.sleep(float("inf")) + + # the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl, + # so we can replace _process_batch_impl with our mock to control concurrency + provider._process_batch_impl = add_and_wait + + for _ in range(3): + await provider.create_batch( + input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h" + ) + + await asyncio.sleep(0.042) # let tasks start + + assert active_batches == 2, f"Expected 2 active batches, got {active_batches}" From ee7631b6cf23793b3921645b896fef45c10aaea7 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 14 Aug 2025 10:08:54 -0700 Subject: [PATCH 26/45] Revert "feat: add batches API with OpenAI compatibility" (#3149) Reverts llamastack/llama-stack#3088 The PR broke integration tests. 
--- docs/_static/llama-stack-spec.html | 6 +- docs/_static/llama-stack-spec.yaml | 2 - docs/source/concepts/apis.md | 1 - docs/source/providers/agents/index.md | 9 - docs/source/providers/batches/index.md | 21 - .../providers/batches/inline_reference.md | 23 - docs/source/providers/eval/index.md | 2 - docs/source/providers/inference/index.md | 6 - llama_stack/apis/batches/__init__.py | 9 - llama_stack/apis/batches/batches.py | 89 --- llama_stack/apis/common/errors.py | 6 - llama_stack/apis/datatypes.py | 2 - llama_stack/apis/files/files.py | 1 - llama_stack/core/resolver.py | 2 - llama_stack/core/server/server.py | 5 - .../providers/inline/batches/__init__.py | 5 - .../inline/batches/reference/__init__.py | 36 - .../inline/batches/reference/batches.py | 553 ------------- .../inline/batches/reference/config.py | 40 - llama_stack/providers/registry/batches.py | 26 - scripts/provider_codegen.py | 22 - tests/integration/batches/__init__.py | 5 - tests/integration/batches/conftest.py | 122 --- tests/integration/batches/test_batches.py | 270 ------- .../batches/test_batches_errors.py | 693 ---------------- .../unit/providers/batches/test_reference.py | 753 ------------------ 26 files changed, 2 insertions(+), 2707 deletions(-) delete mode 100644 docs/source/providers/batches/index.md delete mode 100644 docs/source/providers/batches/inline_reference.md delete mode 100644 llama_stack/apis/batches/__init__.py delete mode 100644 llama_stack/apis/batches/batches.py delete mode 100644 llama_stack/providers/inline/batches/__init__.py delete mode 100644 llama_stack/providers/inline/batches/reference/__init__.py delete mode 100644 llama_stack/providers/inline/batches/reference/batches.py delete mode 100644 llama_stack/providers/inline/batches/reference/config.py delete mode 100644 llama_stack/providers/registry/batches.py delete mode 100644 tests/integration/batches/__init__.py delete mode 100644 tests/integration/batches/conftest.py delete mode 100644 tests/integration/batches/test_batches.py delete mode 100644 tests/integration/batches/test_batches_errors.py delete mode 100644 tests/unit/providers/batches/test_reference.py diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index b36626719..0549dda21 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -14767,8 +14767,7 @@ "OpenAIFilePurpose": { "type": "string", "enum": [ - "assistants", - "batch" + "assistants" ], "title": "OpenAIFilePurpose", "description": "Valid purpose values for OpenAI Files API." @@ -14845,8 +14844,7 @@ "purpose": { "type": "string", "enum": [ - "assistants", - "batch" + "assistants" ], "description": "The intended purpose of the file" } diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index e7733b3c3..aa47cd58d 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -10951,7 +10951,6 @@ components: type: string enum: - assistants - - batch title: OpenAIFilePurpose description: >- Valid purpose values for OpenAI Files API. @@ -11020,7 +11019,6 @@ components: type: string enum: - assistants - - batch description: The intended purpose of the file additionalProperties: false required: diff --git a/docs/source/concepts/apis.md b/docs/source/concepts/apis.md index f8f73a928..5a10d6498 100644 --- a/docs/source/concepts/apis.md +++ b/docs/source/concepts/apis.md @@ -18,4 +18,3 @@ We are working on adding a few more APIs to complete the application lifecycle. 
- **Batch Inference**: run inference on a dataset of inputs - **Batch Agents**: run agents on a dataset of inputs - **Synthetic Data Generation**: generate synthetic data for model development -- **Batches**: OpenAI-compatible batch management for inference diff --git a/docs/source/providers/agents/index.md b/docs/source/providers/agents/index.md index a2c48d4b9..92bf9edc0 100644 --- a/docs/source/providers/agents/index.md +++ b/docs/source/providers/agents/index.md @@ -2,15 +2,6 @@ ## Overview -Agents API for creating and interacting with agentic systems. - - Main functionalities provided by this API: - - Create agents with specific instructions and ability to use tools. - - Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn". - - Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details). - - Agents can be provided with various shields (see the Safety API for more details). - - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. - This section contains documentation for all available providers for the **agents** API. ## Providers diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md deleted file mode 100644 index 2a39a626c..000000000 --- a/docs/source/providers/batches/index.md +++ /dev/null @@ -1,21 +0,0 @@ -# Batches - -## Overview - -Protocol for batch processing API operations. - - The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. - - Note: This API is currently under active development and may undergo changes. - -This section contains documentation for all available providers for the **batches** API. - -## Providers - -```{toctree} -:maxdepth: 1 - -inline_reference -``` diff --git a/docs/source/providers/batches/inline_reference.md b/docs/source/providers/batches/inline_reference.md deleted file mode 100644 index a58e5124d..000000000 --- a/docs/source/providers/batches/inline_reference.md +++ /dev/null @@ -1,23 +0,0 @@ -# inline::reference - -## Description - -Reference implementation of batches API with KVStore persistence. - -## Configuration - -| Field | Type | Required | Default | Description | -|-------|------|----------|---------|-------------| -| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. | -| `max_concurrent_batches` | `` | No | 1 | Maximum number of concurrent batches to process simultaneously. | -| `max_concurrent_requests_per_batch` | `` | No | 10 | Maximum number of concurrent requests to process per batch. | - -## Sample Configuration - -```yaml -kvstore: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db - -``` - diff --git a/docs/source/providers/eval/index.md b/docs/source/providers/eval/index.md index a14fada1d..d180d256c 100644 --- a/docs/source/providers/eval/index.md +++ b/docs/source/providers/eval/index.md @@ -2,8 +2,6 @@ ## Overview -Llama Stack Evaluation API for running evaluations on model and agent candidates. - This section contains documentation for all available providers for the **eval** API. 
## Providers diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index b6d215474..38781e5eb 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -2,12 +2,6 @@ ## Overview -Llama Stack Inference API for generating completions, chat completions, and embeddings. - - This API provides the raw interface to the underlying models. Two kinds of models are supported: - - LLM models: these models generate "raw" and "chat" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search. - This section contains documentation for all available providers for the **inference** API. ## Providers diff --git a/llama_stack/apis/batches/__init__.py b/llama_stack/apis/batches/__init__.py deleted file mode 100644 index 9ce7d3d75..000000000 --- a/llama_stack/apis/batches/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .batches import Batches, BatchObject, ListBatchesResponse - -__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py deleted file mode 100644 index 9297d8597..000000000 --- a/llama_stack/apis/batches/batches.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Literal, Protocol, runtime_checkable - -from pydantic import BaseModel, Field - -from llama_stack.schema_utils import json_schema_type, webmethod - -try: - from openai.types import Batch as BatchObject -except ImportError as e: - raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e - - -@json_schema_type -class ListBatchesResponse(BaseModel): - """Response containing a list of batch objects.""" - - object: Literal["list"] = "list" - data: list[BatchObject] = Field(..., description="List of batch objects") - first_id: str | None = Field(default=None, description="ID of the first batch in the list") - last_id: str | None = Field(default=None, description="ID of the last batch in the list") - has_more: bool = Field(default=False, description="Whether there are more batches available") - - -@runtime_checkable -class Batches(Protocol): - """Protocol for batch processing API operations. - - The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. - - Note: This API is currently under active development and may undergo changes. - """ - - @webmethod(route="/openai/v1/batches", method="POST") - async def create_batch( - self, - input_file_id: str, - endpoint: str, - completion_window: Literal["24h"], - metadata: dict[str, str] | None = None, - ) -> BatchObject: - """Create a new batch for processing multiple API requests. - - :param input_file_id: The ID of an uploaded file containing requests for the batch. - :param endpoint: The endpoint to be used for all requests in the batch. - :param completion_window: The time window within which the batch should be processed. 
- :param metadata: Optional metadata for the batch. - :returns: The created batch object. - """ - ... - - @webmethod(route="/openai/v1/batches/{batch_id}", method="GET") - async def retrieve_batch(self, batch_id: str) -> BatchObject: - """Retrieve information about a specific batch. - - :param batch_id: The ID of the batch to retrieve. - :returns: The batch object. - """ - ... - - @webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST") - async def cancel_batch(self, batch_id: str) -> BatchObject: - """Cancel a batch that is in progress. - - :param batch_id: The ID of the batch to cancel. - :returns: The updated batch object. - """ - ... - - @webmethod(route="/openai/v1/batches", method="GET") - async def list_batches( - self, - after: str | None = None, - limit: int = 20, - ) -> ListBatchesResponse: - """List all batches for the current user. - - :param after: A cursor for pagination; returns batches after this batch ID. - :param limit: Number of batches to return (default 20, max 100). - :returns: A list of batch objects. - """ - ... diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py index 7104d8db6..6e0fa0b3c 100644 --- a/llama_stack/apis/common/errors.py +++ b/llama_stack/apis/common/errors.py @@ -64,12 +64,6 @@ class SessionNotFoundError(ValueError): super().__init__(message) -class ConflictError(ValueError): - """raised when an operation cannot be performed due to a conflict with the current state""" - - pass - - class ModelTypeError(TypeError): """raised when a model is present but not the correct type""" diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index 87fc95917..cabe46a2f 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -86,7 +86,6 @@ class Api(Enum, metaclass=DynamicApiMeta): :cvar inference: Text generation, chat completions, and embeddings :cvar safety: Content moderation and safety shields :cvar agents: Agent orchestration and execution - :cvar batches: Batch processing for asynchronous API requests :cvar vector_io: Vector database operations and queries :cvar datasetio: Dataset input/output operations :cvar scoring: Model output evaluation and scoring @@ -109,7 +108,6 @@ class Api(Enum, metaclass=DynamicApiMeta): inference = "inference" safety = "safety" agents = "agents" - batches = "batches" vector_io = "vector_io" datasetio = "datasetio" scoring = "scoring" diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py index a1b9dd4dc..ba8701e23 100644 --- a/llama_stack/apis/files/files.py +++ b/llama_stack/apis/files/files.py @@ -22,7 +22,6 @@ class OpenAIFilePurpose(StrEnum): """ ASSISTANTS = "assistants" - BATCH = "batch" # TODO: Add other purposes as needed diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py index 7ac98dac8..70c78fb01 100644 --- a/llama_stack/core/resolver.py +++ b/llama_stack/core/resolver.py @@ -8,7 +8,6 @@ import inspect from typing import Any from llama_stack.apis.agents import Agents -from llama_stack.apis.batches import Batches from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -76,7 +75,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> Api.agents: Agents, Api.inference: Inference, Api.inspect: Inspect, - Api.batches: Batches, Api.vector_io: VectorIO, Api.vector_dbs: VectorDBs, Api.models: Models, diff --git a/llama_stack/core/server/server.py 
b/llama_stack/core/server/server.py index cbef8ef88..e9d70fc8d 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -32,7 +32,6 @@ from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError -from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.cli.utils import add_config_distro_args, get_config_from_args from llama_stack.core.access_control.access_control import AccessDeniedError @@ -129,10 +128,6 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro ] }, ) - elif isinstance(exc, ConflictError): - return HTTPException(status_code=409, detail=str(exc)) - elif isinstance(exc, ResourceNotFoundError): - return HTTPException(status_code=404, detail=str(exc)) elif isinstance(exc, ValueError): return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") elif isinstance(exc, BadRequestError): diff --git a/llama_stack/providers/inline/batches/__init__.py b/llama_stack/providers/inline/batches/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/batches/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/batches/reference/__init__.py b/llama_stack/providers/inline/batches/reference/__init__.py deleted file mode 100644 index a8ae92eb2..000000000 --- a/llama_stack/providers/inline/batches/reference/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Any - -from llama_stack.apis.files import Files -from llama_stack.apis.inference import Inference -from llama_stack.apis.models import Models -from llama_stack.core.datatypes import AccessRule, Api -from llama_stack.providers.utils.kvstore import kvstore_impl - -from .batches import ReferenceBatchesImpl -from .config import ReferenceBatchesImplConfig - -__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"] - - -async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]): - kvstore = await kvstore_impl(config.kvstore) - inference_api: Inference | None = deps.get(Api.inference) - files_api: Files | None = deps.get(Api.files) - models_api: Models | None = deps.get(Api.models) - - if inference_api is None: - raise ValueError("Inference API is required but not provided in dependencies") - if files_api is None: - raise ValueError("Files API is required but not provided in dependencies") - if models_api is None: - raise ValueError("Models API is required but not provided in dependencies") - - impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py deleted file mode 100644 index 984ef5a90..000000000 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ /dev/null @@ -1,553 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import asyncio -import itertools -import json -import time -import uuid -from io import BytesIO -from typing import Any, Literal - -from openai.types.batch import BatchError, Errors -from pydantic import BaseModel - -from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse -from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError -from llama_stack.apis.files import Files, OpenAIFilePurpose -from llama_stack.apis.inference import Inference -from llama_stack.apis.models import Models -from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore import KVStore - -from .config import ReferenceBatchesImplConfig - -BATCH_PREFIX = "batch:" - -logger = get_logger(__name__) - - -class AsyncBytesIO: - """ - Async-compatible BytesIO wrapper to allow async file-like operations. - - We use this when uploading files to the Files API, as it expects an - async file-like object. - """ - - def __init__(self, data: bytes): - self._buffer = BytesIO(data) - - async def read(self, n=-1): - return self._buffer.read(n) - - async def seek(self, pos, whence=0): - return self._buffer.seek(pos, whence) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._buffer.close() - - def __getattr__(self, name): - return getattr(self._buffer, name) - - -class BatchRequest(BaseModel): - line_num: int - custom_id: str - method: str - url: str - body: dict[str, Any] - - -class ReferenceBatchesImpl(Batches): - """Reference implementation of the Batches API. - - This implementation processes batch files by making individual requests - to the inference API and generates output files with results. 
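The class docstring above summarizes the processing model: a JSONL input file is fanned out into individual inference requests and the results are written back through the Files API. For context, a hedged client-side sketch of the OpenAI-compatible lifecycle this provider implemented, with endpoint paths and field names taken from the protocol and integration tests removed in this patch (the base URL and model id are assumptions):

```python
# Illustrative client flow only; assumes a Llama Stack server exposing the OpenAI-compatible
# routes at http://localhost:8321/openai/v1 and a placeholder model id "my-model".
import json
from io import BytesIO

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/openai/v1", api_key="none")

# Each JSONL line needs custom_id, method, url, and body (with model and messages).
line = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {"model": "my-model", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 10},
}
buf = BytesIO(json.dumps(line).encode("utf-8"))
buf.name = "batch_input.jsonl"
input_file = client.files.create(file=buf, purpose="batch")

batch = client.batches.create(
    input_file_id=input_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)
final = client.batches.retrieve(batch.id)  # poll until status reaches a terminal value
if final.output_file_id:
    results = client.files.content(final.output_file_id).content.decode("utf-8")
    print(results)  # one JSON record per input line
```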
- """ - - def __init__( - self, - config: ReferenceBatchesImplConfig, - inference_api: Inference, - files_api: Files, - models_api: Models, - kvstore: KVStore, - ) -> None: - self.config = config - self.kvstore = kvstore - self.inference_api = inference_api - self.files_api = files_api - self.models_api = models_api - self._processing_tasks: dict[str, asyncio.Task] = {} - self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches) - self._update_batch_lock = asyncio.Lock() - - # this is to allow tests to disable background processing - self.process_batches = True - - async def initialize(self) -> None: - # TODO: start background processing of existing tasks - pass - - async def shutdown(self) -> None: - """Shutdown the batches provider.""" - if self._processing_tasks: - # don't cancel tasks - just let them stop naturally on shutdown - # cancelling would mark batches as "cancelled" in the database - logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks") - - # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions - async def create_batch( - self, - input_file_id: str, - endpoint: str, - completion_window: Literal["24h"], - metadata: dict[str, str] | None = None, - ) -> BatchObject: - """ - Create a new batch for processing multiple API requests. - - Error handling by levels - - 0. Input param handling, results in 40x errors before processing, e.g. - - Wrong completion_window - - Invalid metadata types - - Unknown endpoint - -> no batch created - 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. - - input_file_id missing - - invalid json in file - - missing custom_id, method, url, body - - invalid model - - streaming - -> batch created, validation sends to failed status - 2. Processing errors, result in error_file_id entries, e.g. - - Any error returned from inference endpoint - -> batch created, goes to completed status - """ - - # TODO: set expiration time for garbage collection - - if endpoint not in ["/v1/chat/completions"]: - raise ValueError( - f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", - ) - - if completion_window != "24h": - raise ValueError( - f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. 
Param: completion_window", - ) - - batch_id = f"batch_{uuid.uuid4().hex[:16]}" - current_time = int(time.time()) - - batch = BatchObject( - id=batch_id, - object="batch", - endpoint=endpoint, - input_file_id=input_file_id, - completion_window=completion_window, - status="validating", - created_at=current_time, - metadata=metadata, - ) - - await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) - - if self.process_batches: - task = asyncio.create_task(self._process_batch(batch_id)) - self._processing_tasks[batch_id] = task - - return batch - - async def cancel_batch(self, batch_id: str) -> BatchObject: - """Cancel a batch that is in progress.""" - batch = await self.retrieve_batch(batch_id) - - if batch.status in ["cancelled", "cancelling"]: - return batch - - if batch.status in ["completed", "failed", "expired"]: - raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'") - - await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time())) - - if batch_id in self._processing_tasks: - self._processing_tasks[batch_id].cancel() - # note: task removal and status="cancelled" handled in finally block of _process_batch - - return await self.retrieve_batch(batch_id) - - async def list_batches( - self, - after: str | None = None, - limit: int = 20, - ) -> ListBatchesResponse: - """ - List all batches, eventually only for the current user. - - With no notion of user, we return all batches. - """ - batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff") - - batches = [] - for batch_data in batch_values: - if batch_data: - batches.append(BatchObject.model_validate_json(batch_data)) - - batches.sort(key=lambda b: b.created_at, reverse=True) - - start_idx = 0 - if after: - for i, batch in enumerate(batches): - if batch.id == after: - start_idx = i + 1 - break - - page_batches = batches[start_idx : start_idx + limit] - has_more = (start_idx + limit) < len(batches) - - first_id = page_batches[0].id if page_batches else None - last_id = page_batches[-1].id if page_batches else None - - return ListBatchesResponse( - data=page_batches, - first_id=first_id, - last_id=last_id, - has_more=has_more, - ) - - async def retrieve_batch(self, batch_id: str) -> BatchObject: - """Retrieve information about a specific batch.""" - batch_data = await self.kvstore.get(f"batch:{batch_id}") - if not batch_data: - raise ResourceNotFoundError(batch_id, "Batch", "batches.list()") - - return BatchObject.model_validate_json(batch_data) - - async def _update_batch(self, batch_id: str, **updates) -> None: - """Update batch fields in kvstore.""" - async with self._update_batch_lock: - try: - batch = await self.retrieve_batch(batch_id) - - # batch processing is async. once cancelling, only allow "cancelled" status updates - if batch.status == "cancelling" and updates.get("status") != "cancelled": - logger.info( - f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}" - ) - return - - if "errors" in updates: - updates["errors"] = updates["errors"].model_dump() - - batch_dict = batch.model_dump() - batch_dict.update(updates) - - await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict)) - except Exception as e: - logger.error(f"Failed to update batch {batch_id}: {e}") - - async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]: - """ - Read & validate input, return errors and valid input. 
- - Validation of - - input_file_id existance - - valid json - - custom_id, method, url, body presence and valid - - no streaming - """ - requests: list[BatchRequest] = [] - errors: list[BatchError] = [] - try: - await self.files_api.openai_retrieve_file(batch.input_file_id) - except Exception: - errors.append( - BatchError( - code="invalid_request", - line=None, - message=f"Cannot find file {batch.input_file_id}.", - param="input_file_id", - ) - ) - return errors, requests - - # TODO(SECURITY): do something about large files - file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id) - file_content = file_content_response.body.decode("utf-8") - for line_num, line in enumerate(file_content.strip().split("\n"), 1): - if line.strip(): # skip empty lines - try: - request = json.loads(line) - - if not isinstance(request, dict): - errors.append( - BatchError( - code="invalid_request", - line=line_num, - message="Each line must be a JSON dictionary object", - ) - ) - continue - - valid = True - - for param, expected_type, type_string in [ - ("custom_id", str, "string"), - ("method", str, "string"), - ("url", str, "string"), - ("body", dict, "JSON dictionary object"), - ]: - if param not in request: - errors.append( - BatchError( - code="missing_required_parameter", - line=line_num, - message=f"Missing required parameter: {param}", - param=param, - ) - ) - valid = False - elif not isinstance(request[param], expected_type): - param_name = "URL" if param == "url" else param.capitalize() - errors.append( - BatchError( - code="invalid_request", - line=line_num, - message=f"{param_name} must be a {type_string}", - param=param, - ) - ) - valid = False - - if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint: - errors.append( - BatchError( - code="invalid_url", - line=line_num, - message="URL provided for this request does not match the batch endpoint", - param="url", - ) - ) - valid = False - - if (body := request.get("body")) and isinstance(body, dict): - if body.get("stream", False): - errors.append( - BatchError( - code="streaming_unsupported", - line=line_num, - message="Streaming is not supported in batch processing", - param="body.stream", - ) - ) - valid = False - - for param, expected_type, type_string in [ - ("model", str, "a string"), - # messages is specific to /v1/chat/completions - # we could skip validating messages here and let inference fail. however, - # that would be a very expensive way to find out messages is wrong. - ("messages", list, "an array"), # TODO: allow messages to be a string? 
- ]: - if param not in body: - errors.append( - BatchError( - code="invalid_request", - line=line_num, - message=f"{param.capitalize()} parameter is required", - param=f"body.{param}", - ) - ) - valid = False - elif not isinstance(body[param], expected_type): - errors.append( - BatchError( - code="invalid_request", - line=line_num, - message=f"{param.capitalize()} must be {type_string}", - param=f"body.{param}", - ) - ) - valid = False - - if "model" in body and isinstance(body["model"], str): - try: - await self.models_api.get_model(body["model"]) - except Exception: - errors.append( - BatchError( - code="model_not_found", - line=line_num, - message=f"Model '{body['model']}' does not exist or is not supported", - param="body.model", - ) - ) - valid = False - - if valid: - assert isinstance(url, str), "URL must be a string" # for mypy - assert isinstance(body, dict), "Body must be a dictionary" # for mypy - requests.append( - BatchRequest( - line_num=line_num, - url=url, - method=request["method"], - custom_id=request["custom_id"], - body=body, - ), - ) - except json.JSONDecodeError: - errors.append( - BatchError( - code="invalid_json_line", - line=line_num, - message="This line is not parseable as valid JSON.", - ) - ) - - return errors, requests - - async def _process_batch(self, batch_id: str) -> None: - """Background task to process a batch of requests.""" - try: - logger.info(f"Starting batch processing for {batch_id}") - async with self._batch_semaphore: # semaphore to limit concurrency - logger.info(f"Acquired semaphore for batch {batch_id}") - await self._process_batch_impl(batch_id) - except asyncio.CancelledError: - logger.info(f"Batch processing cancelled for {batch_id}") - await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time())) - except Exception as e: - logger.error(f"Batch processing failed for {batch_id}: {e}") - await self._update_batch( - batch_id, - status="failed", - failed_at=int(time.time()), - errors=Errors(data=[BatchError(code="internal_error", message=str(e))]), - ) - finally: - self._processing_tasks.pop(batch_id, None) - - async def _process_batch_impl(self, batch_id: str) -> None: - """Implementation of batch processing logic.""" - errors: list[BatchError] = [] - batch = await self.retrieve_batch(batch_id) - - errors, requests = await self._validate_input(batch) - if errors: - await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)) - logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors") - return - - logger.info(f"Processing {len(requests)} requests for batch {batch_id}") - - total_requests = len(requests) - await self._update_batch( - batch_id, - status="in_progress", - request_counts={"total": total_requests, "completed": 0, "failed": 0}, - ) - - error_results = [] - success_results = [] - completed_count = 0 - failed_count = 0 - - for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch): - # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled - async with asyncio.TaskGroup() as tg: - chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk] - - chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True) - - for result in chunk_results: - if isinstance(result, dict) and result.get("error") is not None: # error response from inference - failed_count += 1 - error_results.append(result) - elif isinstance(result, 
dict) and result.get("response") is not None: # successful inference - completed_count += 1 - success_results.append(result) - else: # unexpected result - failed_count += 1 - errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}")) - - await self._update_batch( - batch_id, - request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count}, - ) - - if errors: - await self._update_batch( - batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors) - ) - return - - try: - output_file_id = await self._create_output_file(batch_id, success_results, "success") - await self._update_batch(batch_id, output_file_id=output_file_id) - - error_file_id = await self._create_output_file(batch_id, error_results, "error") - await self._update_batch(batch_id, error_file_id=error_file_id) - - await self._update_batch(batch_id, status="completed", completed_at=int(time.time())) - - logger.info( - f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed" - ) - except Exception as e: - # note: errors is empty at this point, so we don't lose anything by ignoring it - await self._update_batch( - batch_id, - status="failed", - failed_at=int(time.time()), - errors=Errors(data=[BatchError(code="output_failed", message=str(e))]), - ) - - async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict: - """Process a single request from the batch.""" - request_id = f"batch_req_{batch_id}_{request.line_num}" - - try: - # TODO(SECURITY): review body for security issues - chat_response = await self.inference_api.openai_chat_completion(**request.body) - - # this is for mypy, we don't allow streaming so we'll get the right type - assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" - return { - "id": request_id, - "custom_id": request.custom_id, - "response": { - "status_code": 200, - "request_id": request_id, # TODO: should this be different? - "body": chat_response.model_dump_json(), - }, - } - except Exception as e: - logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") - return { - "id": request_id, - "custom_id": request.custom_id, - "error": {"type": "request_failed", "message": str(e)}, - } - - async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str: - """ - Create an output file with batch results. - - This function filters results based on the specified file_type - and uploads the file to the Files API. - """ - output_lines = [json.dumps(result) for result in results] - - with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer: - file_buffer.filename = f"{batch_id}_{file_type}.jsonl" - uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH) - return uploaded_file.id diff --git a/llama_stack/providers/inline/batches/reference/config.py b/llama_stack/providers/inline/batches/reference/config.py deleted file mode 100644 index d8d06868b..000000000 --- a/llama_stack/providers/inline/batches/reference/config.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
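Stepping back from the implementation above: `_process_single_request` and `_create_output_file` determine the shape of each line in the generated output and error files. A sketch of the two record forms, with placeholder identifiers:

```python
# Shape of per-request records in the generated JSONL files (values are placeholders).
success_record = {
    "id": "batch_req_batch_abc123_1",  # f"batch_req_{batch_id}_{line_num}"
    "custom_id": "request-1",
    "response": {
        "status_code": 200,
        "request_id": "batch_req_batch_abc123_1",
        "body": "{...serialized chat completion JSON...}",
    },
}

error_record = {
    "id": "batch_req_batch_abc123_2",
    "custom_id": "request-2",
    "error": {"type": "request_failed", "message": "description of the inference failure"},
}
```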
- -from pydantic import BaseModel, Field - -from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig - - -class ReferenceBatchesImplConfig(BaseModel): - """Configuration for the Reference Batches implementation.""" - - kvstore: KVStoreConfig = Field( - description="Configuration for the key-value store backend.", - ) - - max_concurrent_batches: int = Field( - default=1, - description="Maximum number of concurrent batches to process simultaneously.", - ge=1, - ) - - max_concurrent_requests_per_batch: int = Field( - default=10, - description="Maximum number of concurrent requests to process per batch.", - ge=1, - ) - - # TODO: add a max requests per second rate limiter - - @classmethod - def sample_run_config(cls, __distro_dir__: str) -> dict: - return { - "kvstore": SqliteKVStoreConfig.sample_run_config( - __distro_dir__=__distro_dir__, - db_name="batches.db", - ), - } diff --git a/llama_stack/providers/registry/batches.py b/llama_stack/providers/registry/batches.py deleted file mode 100644 index de7886efb..000000000 --- a/llama_stack/providers/registry/batches.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec - - -def available_providers() -> list[ProviderSpec]: - return [ - InlineProviderSpec( - api=Api.batches, - provider_type="inline::reference", - pip_packages=["openai"], - module="llama_stack.providers.inline.batches.reference", - config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig", - api_dependencies=[ - Api.inference, - Api.files, - Api.models, - ], - description="Reference implementation of batches API with KVStore persistence.", - ), - ] diff --git a/scripts/provider_codegen.py b/scripts/provider_codegen.py index 060acfa72..717677c52 100755 --- a/scripts/provider_codegen.py +++ b/scripts/provider_codegen.py @@ -18,23 +18,6 @@ from llama_stack.core.distribution import get_provider_registry REPO_ROOT = Path(__file__).parent.parent -def get_api_docstring(api_name: str) -> str | None: - """Extract docstring from the API protocol class.""" - try: - # Import the API module dynamically - api_module = __import__(f"llama_stack.apis.{api_name}", fromlist=[api_name.title()]) - - # Get the main protocol class (usually capitalized API name) - protocol_class_name = api_name.title() - if hasattr(api_module, protocol_class_name): - protocol_class = getattr(api_module, protocol_class_name) - return protocol_class.__doc__ - except (ImportError, AttributeError): - pass - - return None - - class ChangedPathTracker: """Track a list of paths we may have changed.""" @@ -278,11 +261,6 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N index_content.append(f"# {api_name.title()}\n") index_content.append("## Overview\n") - api_docstring = get_api_docstring(api_name) - if api_docstring: - cleaned_docstring = api_docstring.strip() - index_content.append(f"{cleaned_docstring}\n") - index_content.append( f"This section contains documentation for all available providers for the **{api_name}** API.\n" ) diff --git a/tests/integration/batches/__init__.py b/tests/integration/batches/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/tests/integration/batches/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) 
Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/tests/integration/batches/conftest.py b/tests/integration/batches/conftest.py deleted file mode 100644 index 974fe77ab..000000000 --- a/tests/integration/batches/conftest.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -"""Shared pytest fixtures for batch tests.""" - -import json -import time -import warnings -from contextlib import contextmanager -from io import BytesIO - -import pytest - -from llama_stack.apis.files import OpenAIFilePurpose - - -class BatchHelper: - """Helper class for creating and managing batch input files.""" - - def __init__(self, client): - """Initialize with either a batch_client or openai_client.""" - self.client = client - - @contextmanager - def create_file(self, content: str | list[dict], filename_prefix="batch_input"): - """Context manager for creating and cleaning up batch input files. - - Args: - content: Either a list of batch request dictionaries or raw string content - filename_prefix: Prefix for the generated filename (or full filename if content is string) - - Yields: - The uploaded file object - """ - if isinstance(content, str): - # Handle raw string content (e.g., malformed JSONL, empty files) - file_content = content.encode("utf-8") - else: - # Handle list of batch request dictionaries - jsonl_content = "\n".join(json.dumps(req) for req in content) - file_content = jsonl_content.encode("utf-8") - - filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl" - - with BytesIO(file_content) as file_buffer: - file_buffer.name = filename - uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH) - - try: - yield uploaded_file - finally: - try: - self.client.files.delete(uploaded_file.id) - except Exception: - warnings.warn( - f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}", - stacklevel=2, - ) - - def wait_for( - self, - batch_id: str, - max_wait_time: int = 60, - sleep_interval: int | None = None, - expected_statuses: set[str] | None = None, - timeout_action: str = "fail", - ): - """Wait for a batch to reach a terminal status. 
- - Args: - batch_id: The batch ID to monitor - max_wait_time: Maximum time to wait in seconds (default: 60 seconds) - sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s) - expected_statuses: Set of expected terminal statuses (default: {"completed"}) - timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip) - - Returns: - The final batch object - - Raises: - pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail" - pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status - """ - if sleep_interval is None: - # Default to 1/10th of max_wait_time, with min 1s and max 15s - sleep_interval = max(1, min(15, max_wait_time // 10)) - - if expected_statuses is None: - expected_statuses = {"completed"} - - terminal_statuses = {"completed", "failed", "cancelled", "expired"} - unexpected_statuses = terminal_statuses - expected_statuses - - start_time = time.time() - while time.time() - start_time < max_wait_time: - current_batch = self.client.batches.retrieve(batch_id) - - if current_batch.status in expected_statuses: - return current_batch - elif current_batch.status in unexpected_statuses: - error_msg = f"Batch reached unexpected status: {current_batch.status}" - if timeout_action == "skip": - pytest.skip(error_msg) - else: - pytest.fail(error_msg) - - time.sleep(sleep_interval) - - timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds" - if timeout_action == "skip": - pytest.skip(timeout_msg) - else: - pytest.fail(timeout_msg) - - -@pytest.fixture -def batch_helper(openai_client): - """Fixture that provides a BatchHelper instance for OpenAI client.""" - return BatchHelper(openai_client) diff --git a/tests/integration/batches/test_batches.py b/tests/integration/batches/test_batches.py deleted file mode 100644 index 1ef3202d0..000000000 --- a/tests/integration/batches/test_batches.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -""" -Integration tests for the Llama Stack batch processing functionality. - -This module contains comprehensive integration tests for the batch processing API, -using the OpenAI-compatible client interface for consistency. - -Test Categories: - 1. Core Batch Operations: - - test_batch_creation_and_retrieval: Comprehensive batch creation, structure validation, and retrieval - - test_batch_listing: Basic batch listing functionality - - test_batch_immediate_cancellation: Batch cancellation workflow - # TODO: cancel during processing - - 2. End-to-End Processing: - - test_batch_e2e_chat_completions: Full chat completions workflow with output and error validation - -Note: Error conditions and edge cases are primarily tested in test_batches_errors.py -for better organization and separation of concerns. - -CLEANUP WARNING: These tests currently create batches that are not automatically -cleaned up after test completion. This may lead to resource accumulation over -multiple test runs. Only test_batch_immediate_cancellation properly cancels its batch. -The test_batch_e2e_chat_completions test does clean up its output and error files. 
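Taken together, `create_file` and `wait_for` give the tests in this file a compact pattern for exercising a batch end to end. A brief sketch of that pattern, assuming the `openai_client` and `text_model_id` fixtures used throughout this suite:

```python
# Hedged sketch of the helper pattern; illustrative, not one of the removed tests.
def test_minimal_batch(openai_client, batch_helper, text_model_id):
    requests = [{
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"model": text_model_id, "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 5},
    }]
    # Keep the input file alive (inside the context manager) until the batch finishes;
    # once the helper cleans the file up, batch validation can no longer find it.
    with batch_helper.create_file(requests, "minimal_batch_input") as uploaded_file:
        batch = openai_client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )
        final = batch_helper.wait_for(
            batch.id, max_wait_time=120, expected_statuses={"completed"}, timeout_action="skip"
        )
    assert final.status == "completed"
```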
-""" - -import json - - -class TestBatchesIntegration: - """Integration tests for the batches API.""" - - def test_batch_creation_and_retrieval(self, openai_client, batch_helper, text_model_id): - """Test comprehensive batch creation and retrieval scenarios.""" - test_metadata = { - "test_type": "comprehensive", - "purpose": "creation_and_retrieval_test", - "version": "1.0", - "tags": "test,batch", - } - - batch_requests = [ - { - "custom_id": "request-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests, "batch_creation_test") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - metadata=test_metadata, - ) - - assert batch.endpoint == "/v1/chat/completions" - assert batch.input_file_id == uploaded_file.id - assert batch.completion_window == "24h" - assert batch.metadata == test_metadata - - retrieved_batch = openai_client.batches.retrieve(batch.id) - - assert retrieved_batch.id == batch.id - assert retrieved_batch.object == batch.object - assert retrieved_batch.endpoint == batch.endpoint - assert retrieved_batch.input_file_id == batch.input_file_id - assert retrieved_batch.completion_window == batch.completion_window - assert retrieved_batch.metadata == batch.metadata - - def test_batch_listing(self, openai_client, batch_helper, text_model_id): - """ - Test batch listing. - - This test creates multiple batches and verifies that they can be listed. - It also deletes the input files before execution, which means the batches - will appear as failed due to missing input files. This is expected and - a good thing, because it means no inference is performed. 
- """ - batch_ids = [] - - for i in range(2): - batch_requests = [ - { - "custom_id": f"request-{i}", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": f"Hello {i}"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests, f"batch_input_{i}") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - batch_ids.append(batch.id) - - batch_list = openai_client.batches.list() - - assert isinstance(batch_list.data, list) - - listed_batch_ids = {b.id for b in batch_list.data} - for batch_id in batch_ids: - assert batch_id in listed_batch_ids - - def test_batch_immediate_cancellation(self, openai_client, batch_helper, text_model_id): - """Test immediate batch cancellation.""" - batch_requests = [ - { - "custom_id": "request-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - # hopefully cancel the batch before it completes - cancelling_batch = openai_client.batches.cancel(batch.id) - assert cancelling_batch.status in ["cancelling", "cancelled"] - assert isinstance(cancelling_batch.cancelling_at, int), ( - f"cancelling_at should be int, got {type(cancelling_batch.cancelling_at)}" - ) - - final_batch = batch_helper.wait_for( - batch.id, - max_wait_time=3 * 60, # often takes 10-11 minutes, give it 3 min - expected_statuses={"cancelled"}, - timeout_action="skip", - ) - - assert final_batch.status == "cancelled" - assert isinstance(final_batch.cancelled_at, int), ( - f"cancelled_at should be int, got {type(final_batch.cancelled_at)}" - ) - - def test_batch_e2e_chat_completions(self, openai_client, batch_helper, text_model_id): - """Test end-to-end batch processing for chat completions with both successful and failed operations.""" - batch_requests = [ - { - "custom_id": "success-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Say hello"}], - "max_tokens": 20, - }, - }, - { - "custom_id": "error-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "This should fail"}], - "max_tokens": -1, # Invalid negative max_tokens will cause inference error - }, - }, - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={"test": "e2e_success_and_errors_test"}, - ) - - final_batch = batch_helper.wait_for( - batch.id, - max_wait_time=3 * 60, # often takes 2-3 minutes - expected_statuses={"completed"}, - timeout_action="skip", - ) - - # Expecting a completed batch with both successful and failed requests - # Batch(id='batch_xxx', - # completion_window='24h', - # created_at=..., - # endpoint='/v1/chat/completions', - # input_file_id='file-xxx', - # object='batch', - # status='completed', - # output_file_id='file-xxx', - # error_file_id='file-xxx', - # request_counts=BatchRequestCounts(completed=1, failed=1, total=2)) - - assert 
final_batch.status == "completed" - assert final_batch.request_counts is not None - assert final_batch.request_counts.total == 2 - assert final_batch.request_counts.completed == 1 - assert final_batch.request_counts.failed == 1 - - assert final_batch.output_file_id is not None, "Output file should exist for successful requests" - - output_content = openai_client.files.content(final_batch.output_file_id) - if isinstance(output_content, str): - output_text = output_content - else: - output_text = output_content.content.decode("utf-8") - - output_lines = output_text.strip().split("\n") - - for line in output_lines: - result = json.loads(line) - - assert "id" in result - assert "custom_id" in result - assert result["custom_id"] == "success-1" - - assert "response" in result - - assert result["response"]["status_code"] == 200 - assert "body" in result["response"] - assert "choices" in result["response"]["body"] - - assert final_batch.error_file_id is not None, "Error file should exist for failed requests" - - error_content = openai_client.files.content(final_batch.error_file_id) - if isinstance(error_content, str): - error_text = error_content - else: - error_text = error_content.content.decode("utf-8") - - error_lines = error_text.strip().split("\n") - - for line in error_lines: - result = json.loads(line) - - assert "id" in result - assert "custom_id" in result - assert result["custom_id"] == "error-1" - assert "error" in result - error = result["error"] - assert error is not None - assert "code" in error or "message" in error, "Error should have code or message" - - deleted_output_file = openai_client.files.delete(final_batch.output_file_id) - assert deleted_output_file.deleted, f"Output file {final_batch.output_file_id} was not deleted successfully" - - deleted_error_file = openai_client.files.delete(final_batch.error_file_id) - assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully" diff --git a/tests/integration/batches/test_batches_errors.py b/tests/integration/batches/test_batches_errors.py deleted file mode 100644 index bc94a182e..000000000 --- a/tests/integration/batches/test_batches_errors.py +++ /dev/null @@ -1,693 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -""" -Error handling and edge case tests for the Llama Stack batch processing functionality. - -This module focuses exclusively on testing error conditions, validation failures, -and edge cases for batch operations to ensure robust error handling and graceful -degradation. - -Test Categories: - 1. File and Input Validation: - - test_batch_nonexistent_file_id: Handling invalid file IDs - - test_batch_malformed_jsonl: Processing malformed JSONL input files - - test_file_malformed_batch_file: Handling malformed files at upload time - - test_batch_missing_required_fields: Validation of required request fields - - 2. API Endpoint and Model Validation: - - test_batch_invalid_endpoint: Invalid endpoint handling during creation - - test_batch_error_handling_invalid_model: Error handling with nonexistent models - - test_batch_endpoint_mismatch: Validation of endpoint/URL consistency - - 3. 
Batch Lifecycle Error Handling: - - test_batch_retrieve_nonexistent: Retrieving non-existent batches - - test_batch_cancel_nonexistent: Cancelling non-existent batches - - test_batch_cancel_completed: Attempting to cancel completed batches - - 4. Parameter and Configuration Validation: - - test_batch_invalid_completion_window: Invalid completion window values - - test_batch_invalid_metadata_types: Invalid metadata type validation - - test_batch_missing_required_body_fields: Validation of required fields in request body - - 5. Feature Restriction and Compatibility: - - test_batch_streaming_not_supported: Streaming request rejection - - test_batch_mixed_streaming_requests: Mixed streaming/non-streaming validation - -Note: Core functionality and OpenAI compatibility tests are located in -test_batches_integration.py for better organization and separation of concerns. - -CLEANUP WARNING: These tests create batches to test error conditions but do not -automatically clean them up after test completion. While most error tests create -batches that fail quickly, some may create valid batches that consume resources. -""" - -import pytest -from openai import BadRequestError, ConflictError, NotFoundError - - -class TestBatchesErrorHandling: - """Error handling and edge case tests for the batches API using OpenAI client.""" - - def test_batch_nonexistent_file_id(self, openai_client, batch_helper): - """Test batch creation with nonexistent input file ID.""" - - batch = openai_client.batches.create( - input_file_id="file-nonexistent-xyz", - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError( - # code='invalid_request', - # line=None, - # message='Cannot find file ..., or organization ... does not have access to it.', - # param='file_id') - # ], object='list'), - # failed_at=1754566971, - # ...) - - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 1 - error = final_batch.errors.data[0] - assert error.code == "invalid_request" - assert "cannot find file" in error.message.lower() - - def test_batch_invalid_endpoint(self, openai_client, batch_helper, text_model_id): - """Test batch creation with invalid endpoint.""" - batch_requests = [ - { - "custom_id": "invalid-endpoint", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - with pytest.raises(BadRequestError) as exc_info: - openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/invalid/endpoint", - completion_window="24h", - ) - - # Expected - - # Error code: 400 - { - # 'error': { - # 'message': "Invalid value: '/v1/invalid/endpoint'. 
Supported values are: '/v1/chat/completions', '/v1/completions', '/v1/embeddings', and '/v1/responses'.", - # 'type': 'invalid_request_error', - # 'param': 'endpoint', - # 'code': 'invalid_value' - # } - # } - - error_msg = str(exc_info.value).lower() - assert exc_info.value.status_code == 400 - assert "invalid value" in error_msg - assert "/v1/invalid/endpoint" in error_msg - assert "supported values" in error_msg - assert "endpoint" in error_msg - assert "invalid_value" in error_msg - - def test_batch_malformed_jsonl(self, openai_client, batch_helper): - """ - Test batch with malformed JSONL input. - - The /v1/files endpoint requires valid JSONL format, so we provide a well formed line - before a malformed line to ensure we get to the /v1/batches validation stage. - """ - with batch_helper.create_file( - """{"custom_id": "valid", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test"}} -{invalid json here""", - "malformed_batch_input.jsonl", - ) as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # ..., - # BatchError(code='invalid_json_line', - # line=2, - # message='This line is not parseable as valid JSON.', - # param=None) - # ], object='list'), - # ...) - - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) > 0 - error = final_batch.errors.data[-1] # get last error because first may be about the "test" model - assert error.code == "invalid_json_line" - assert error.line == 2 - assert "not" in error.message.lower() - assert "valid json" in error.message.lower() - - @pytest.mark.xfail(reason="Not all file providers validate content") - @pytest.mark.parametrize("batch_requests", ["", "{malformed json"], ids=["empty", "malformed"]) - def test_file_malformed_batch_file(self, openai_client, batch_helper, batch_requests): - """Test file upload with malformed content.""" - - with pytest.raises(BadRequestError) as exc_info: - with batch_helper.create_file(batch_requests, "malformed_batch_input_file.jsonl"): - # /v1/files rejects the file, we don't get to batch creation - pass - - error_msg = str(exc_info.value).lower() - assert exc_info.value.status_code == 400 - assert "invalid file format" in error_msg - assert "jsonl" in error_msg - - def test_batch_retrieve_nonexistent(self, openai_client): - """Test retrieving nonexistent batch.""" - with pytest.raises(NotFoundError) as exc_info: - openai_client.batches.retrieve("batch-nonexistent-xyz") - - error_msg = str(exc_info.value).lower() - assert exc_info.value.status_code == 404 - assert "no batch found" in error_msg or "not found" in error_msg - - def test_batch_cancel_nonexistent(self, openai_client): - """Test cancelling nonexistent batch.""" - with pytest.raises(NotFoundError) as exc_info: - openai_client.batches.cancel("batch-nonexistent-xyz") - - error_msg = str(exc_info.value).lower() - assert exc_info.value.status_code == 404 - assert "no batch found" in error_msg or "not found" in error_msg - - def test_batch_cancel_completed(self, openai_client, batch_helper, text_model_id): - """Test cancelling already completed batch.""" - batch_requests = [ - { - "custom_id": "cancel-completed", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": 
[{"role": "user", "content": "Quick test"}], - "max_tokens": 5, - }, - } - ] - - with batch_helper.create_file(batch_requests, "cancel_test_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for( - batch.id, - max_wait_time=3 * 60, # often take 10-11 min, give it 3 min - expected_statuses={"completed"}, - timeout_action="skip", - ) - - deleted_file = openai_client.files.delete(final_batch.output_file_id) - assert deleted_file.deleted, f"File {final_batch.output_file_id} was not deleted successfully" - - with pytest.raises(ConflictError) as exc_info: - openai_client.batches.cancel(batch.id) - - # Expecting - - # Error code: 409 - { - # 'error': { - # 'message': "Cannot cancel a batch with status 'completed'.", - # 'type': 'invalid_request_error', - # 'param': None, - # 'code': None - # } - # } - # - # NOTE: Same for "failed", cancelling "cancelled" batches is allowed - - error_msg = str(exc_info.value).lower() - assert exc_info.value.status_code == 409 - assert "cannot cancel" in error_msg - - def test_batch_missing_required_fields(self, openai_client, batch_helper, text_model_id): - """Test batch with requests missing required fields.""" - batch_requests = [ - { - # Missing custom_id - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "No custom_id"}], - "max_tokens": 10, - }, - }, - { - "custom_id": "no-method", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "No method"}], - "max_tokens": 10, - }, - }, - { - "custom_id": "no-url", - "method": "POST", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "No URL"}], - "max_tokens": 10, - }, - }, - { - "custom_id": "no-body", - "method": "POST", - "url": "/v1/chat/completions", - }, - ] - - with batch_helper.create_file(batch_requests, "missing_fields_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors( - # data=[ - # BatchError( - # code='missing_required_parameter', - # line=1, - # message="Missing required parameter: 'custom_id'.", - # param='custom_id' - # ), - # BatchError( - # code='missing_required_parameter', - # line=2, - # message="Missing required parameter: 'method'.", - # param='method' - # ), - # BatchError( - # code='missing_required_parameter', - # line=3, - # message="Missing required parameter: 'url'.", - # param='url' - # ), - # BatchError( - # code='missing_required_parameter', - # line=4, - # message="Missing required parameter: 'body'.", - # param='body' - # ) - # ], object='list'), - # failed_at=1754566945, - # ...) 
- # ) - - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 4 - no_custom_id_error = final_batch.errors.data[0] - assert no_custom_id_error.code == "missing_required_parameter" - assert no_custom_id_error.line == 1 - assert "missing" in no_custom_id_error.message.lower() - assert "custom_id" in no_custom_id_error.message.lower() - no_method_error = final_batch.errors.data[1] - assert no_method_error.code == "missing_required_parameter" - assert no_method_error.line == 2 - assert "missing" in no_method_error.message.lower() - assert "method" in no_method_error.message.lower() - no_url_error = final_batch.errors.data[2] - assert no_url_error.code == "missing_required_parameter" - assert no_url_error.line == 3 - assert "missing" in no_url_error.message.lower() - assert "url" in no_url_error.message.lower() - no_body_error = final_batch.errors.data[3] - assert no_body_error.code == "missing_required_parameter" - assert no_body_error.line == 4 - assert "missing" in no_body_error.message.lower() - assert "body" in no_body_error.message.lower() - - def test_batch_invalid_completion_window(self, openai_client, batch_helper, text_model_id): - """Test batch creation with invalid completion window.""" - batch_requests = [ - { - "custom_id": "invalid-completion-window", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - for window in ["1h", "48h", "invalid", ""]: - with pytest.raises(BadRequestError) as exc_info: - openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window=window, - ) - assert exc_info.value.status_code == 400 - error_msg = str(exc_info.value).lower() - assert "error" in error_msg - assert "completion_window" in error_msg - - def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id): - """Test that streaming responses are not supported in batches.""" - batch_requests = [ - { - "custom_id": "streaming-test", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - "stream": True, # Not supported - }, - } - ] - - with batch_helper.create_file(batch_requests, "streaming_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError(code='streaming_unsupported', - # line=1, - # message='Chat Completions: Streaming is not supported in the Batch API.', - # param='body.stream') - # ], object='list'), - # failed_at=1754566965, - # ...) 
- - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 1 - error = final_batch.errors.data[0] - assert error.code == "streaming_unsupported" - assert error.line == 1 - assert "streaming" in error.message.lower() - assert "not supported" in error.message.lower() - assert error.param == "body.stream" - assert final_batch.failed_at is not None - - def test_batch_mixed_streaming_requests(self, openai_client, batch_helper, text_model_id): - """ - Test batch with mixed streaming and non-streaming requests. - - This is distinct from test_batch_streaming_not_supported, which tests a single - streaming request, to ensure an otherwise valid batch fails when a single - streaming request is included. - """ - batch_requests = [ - { - "custom_id": "valid-non-streaming-request", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello without streaming"}], - "max_tokens": 10, - }, - }, - { - "custom_id": "streaming-request", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello with streaming"}], - "max_tokens": 10, - "stream": True, # Not supported - }, - }, - ] - - with batch_helper.create_file(batch_requests, "mixed_streaming_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError( - # code='streaming_unsupported', - # line=2, - # message='Chat Completions: Streaming is not supported in the Batch API.', - # param='body.stream') - # ], object='list'), - # failed_at=1754574442, - # ...) - - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 1 - error = final_batch.errors.data[0] - assert error.code == "streaming_unsupported" - assert error.line == 2 - assert "streaming" in error.message.lower() - assert "not supported" in error.message.lower() - assert error.param == "body.stream" - assert final_batch.failed_at is not None - - def test_batch_endpoint_mismatch(self, openai_client, batch_helper, text_model_id): - """Test batch creation with mismatched endpoint and request URL.""" - batch_requests = [ - { - "custom_id": "endpoint-mismatch", - "method": "POST", - "url": "/v1/embeddings", # Different from batch endpoint - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - }, - } - ] - - with batch_helper.create_file(batch_requests, "endpoint_mismatch_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", # Different from request URL - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError( - # code='invalid_url', - # line=1, - # message='The URL provided for this request does not match the batch endpoint.', - # param='url') - # ], object='list'), - # failed_at=1754566972, - # ...) 
- - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 1 - error = final_batch.errors.data[0] - assert error.line == 1 - assert error.code == "invalid_url" - assert "does not match" in error.message.lower() - assert "endpoint" in error.message.lower() - assert final_batch.failed_at is not None - - def test_batch_error_handling_invalid_model(self, openai_client, batch_helper): - """Test batch error handling with invalid model.""" - batch_requests = [ - { - "custom_id": "invalid-model", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": "nonexistent-model-xyz", - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError(code='model_not_found', - # line=1, - # message="The provided model 'nonexistent-model-xyz' is not supported by the Batch API.", - # param='body.model') - # ], object='list'), - # failed_at=1754566978, - # ...) - - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 1 - error = final_batch.errors.data[0] - assert error.line == 1 - assert error.code == "model_not_found" - assert "not supported" in error.message.lower() - assert error.param == "body.model" - assert final_batch.failed_at is not None - - def test_batch_missing_required_body_fields(self, openai_client, batch_helper, text_model_id): - """Test batch with requests missing required fields in body (model and messages).""" - batch_requests = [ - { - "custom_id": "missing-model", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - # Missing model field - "messages": [{"role": "user", "content": "Hello without model"}], - "max_tokens": 10, - }, - }, - { - "custom_id": "missing-messages", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - # Missing messages field - "max_tokens": 10, - }, - }, - ] - - with batch_helper.create_file(batch_requests, "missing_body_fields_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError( - # code='invalid_request', - # line=1, - # message='Model parameter is required.', - # param='body.model'), - # BatchError( - # code='invalid_request', - # line=2, - # message='Messages parameter is required.', - # param='body.messages') - # ], object='list'), - # ...) 
- - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 2 - - model_error = final_batch.errors.data[0] - assert model_error.line == 1 - assert "model" in model_error.message.lower() - assert model_error.param == "body.model" - - messages_error = final_batch.errors.data[1] - assert messages_error.line == 2 - assert "messages" in messages_error.message.lower() - assert messages_error.param == "body.messages" - - assert final_batch.failed_at is not None - - def test_batch_invalid_metadata_types(self, openai_client, batch_helper, text_model_id): - """Test batch creation with invalid metadata types (like lists).""" - batch_requests = [ - { - "custom_id": "invalid-metadata-type", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - with pytest.raises(Exception) as exc_info: - openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={ - "tags": ["tag1", "tag2"], # Invalid type, should be a string - }, - ) - - # Expecting - - # Error code: 400 - {'error': - # {'message': "Invalid type for 'metadata.tags': expected a string, - # but got an array instead.", - # 'type': 'invalid_request_error', 'param': 'metadata.tags', - # 'code': 'invalid_type'}} - - error_msg = str(exc_info.value).lower() - assert "400" in error_msg - assert "tags" in error_msg - assert "string" in error_msg diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py deleted file mode 100644 index 9fe0cc710..000000000 --- a/tests/unit/providers/batches/test_reference.py +++ /dev/null @@ -1,753 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -""" -Test suite for the reference implementation of the Batches API. 
- -The tests are categorized and outlined below, keep this updated: - -- Batch creation with various parameters and validation: - * test_create_and_retrieve_batch_success (positive) - * test_create_batch_without_metadata (positive) - * test_create_batch_completion_window (negative) - * test_create_batch_invalid_endpoints (negative) - * test_create_batch_invalid_metadata (negative) - -- Batch retrieval and error handling for non-existent batches: - * test_retrieve_batch_not_found (negative) - -- Batch cancellation with proper status transitions: - * test_cancel_batch_success (positive) - * test_cancel_batch_invalid_statuses (negative) - * test_cancel_batch_not_found (negative) - -- Batch listing with pagination and filtering: - * test_list_batches_empty (positive) - * test_list_batches_single_batch (positive) - * test_list_batches_multiple_batches (positive) - * test_list_batches_with_limit (positive) - * test_list_batches_with_pagination (positive) - * test_list_batches_invalid_after (negative) - -- Data persistence in the underlying key-value store: - * test_kvstore_persistence (positive) - -- Batch processing concurrency control: - * test_max_concurrent_batches (positive) - -- Input validation testing (direct _validate_input method tests): - * test_validate_input_file_not_found (negative) - * test_validate_input_file_exists_empty_content (positive) - * test_validate_input_file_mixed_valid_invalid_json (mixed) - * test_validate_input_invalid_model (negative) - * test_validate_input_url_mismatch (negative) - * test_validate_input_multiple_errors_per_request (negative) - * test_validate_input_invalid_request_format (negative) - * test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation) - * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation) - -The tests use temporary SQLite databases for isolation and mock external -dependencies like inference, files, and models APIs. 
-""" - -import json -import tempfile -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from llama_stack.apis.batches import BatchObject -from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError -from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl -from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig - - -class TestReferenceBatchesImpl: - """Test the reference implementation of the Batches API.""" - - @pytest.fixture - async def provider(self): - """Create a test provider instance with temporary database.""" - with tempfile.TemporaryDirectory() as tmpdir: - db_path = Path(tmpdir) / "test_batches.db" - kvstore_config = SqliteKVStoreConfig(db_path=str(db_path)) - config = ReferenceBatchesImplConfig(kvstore=kvstore_config) - - # Create kvstore and mock APIs - from unittest.mock import AsyncMock - - from llama_stack.providers.utils.kvstore import kvstore_impl - - kvstore = await kvstore_impl(config.kvstore) - mock_inference = AsyncMock() - mock_files = AsyncMock() - mock_models = AsyncMock() - - provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore) - await provider.initialize() - - # unit tests should not require background processing - provider.process_batches = False - - yield provider - - await provider.shutdown() - - @pytest.fixture - def sample_batch_data(self): - """Sample batch data for testing.""" - return { - "input_file_id": "file_abc123", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "metadata": {"test": "true", "priority": "high"}, - } - - def _validate_batch_type(self, batch, expected_metadata=None): - """ - Helper function to validate batch object structure and field types. - - Note: This validates the direct BatchObject from the provider, not the - client library response which has a different structure. - - Args: - batch: The BatchObject instance to validate. - expected_metadata: Optional expected metadata dictionary to validate against. 
- """ - assert isinstance(batch.id, str) - assert isinstance(batch.completion_window, str) - assert isinstance(batch.created_at, int) - assert isinstance(batch.endpoint, str) - assert isinstance(batch.input_file_id, str) - assert batch.object == "batch" - assert batch.status in [ - "validating", - "failed", - "in_progress", - "finalizing", - "completed", - "expired", - "cancelling", - "cancelled", - ] - - if expected_metadata is not None: - assert batch.metadata == expected_metadata - - timestamp_fields = [ - "cancelled_at", - "cancelling_at", - "completed_at", - "expired_at", - "expires_at", - "failed_at", - "finalizing_at", - "in_progress_at", - ] - for field in timestamp_fields: - field_value = getattr(batch, field, None) - if field_value is not None: - assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}" - - file_id_fields = ["error_file_id", "output_file_id"] - for field in file_id_fields: - field_value = getattr(batch, field, None) - if field_value is not None: - assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}" - - if hasattr(batch, "request_counts") and batch.request_counts is not None: - assert isinstance(batch.request_counts.completed, int), ( - f"request_counts.completed should be int, got {type(batch.request_counts.completed)}" - ) - assert isinstance(batch.request_counts.failed, int), ( - f"request_counts.failed should be int, got {type(batch.request_counts.failed)}" - ) - assert isinstance(batch.request_counts.total, int), ( - f"request_counts.total should be int, got {type(batch.request_counts.total)}" - ) - - if hasattr(batch, "errors") and batch.errors is not None: - assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}" - - if hasattr(batch.errors, "data") and batch.errors.data is not None: - assert isinstance(batch.errors.data, list), ( - f"errors.data should be list or None, got {type(batch.errors.data)}" - ) - - for i, error_item in enumerate(batch.errors.data): - assert isinstance(error_item, dict), ( - f"errors.data[{i}] should be object or dict, got {type(error_item)}" - ) - - if hasattr(error_item, "code") and error_item.code is not None: - assert isinstance(error_item.code, str), ( - f"errors.data[{i}].code should be str or None, got {type(error_item.code)}" - ) - - if hasattr(error_item, "line") and error_item.line is not None: - assert isinstance(error_item.line, int), ( - f"errors.data[{i}].line should be int or None, got {type(error_item.line)}" - ) - - if hasattr(error_item, "message") and error_item.message is not None: - assert isinstance(error_item.message, str), ( - f"errors.data[{i}].message should be str or None, got {type(error_item.message)}" - ) - - if hasattr(error_item, "param") and error_item.param is not None: - assert isinstance(error_item.param, str), ( - f"errors.data[{i}].param should be str or None, got {type(error_item.param)}" - ) - - if hasattr(batch.errors, "object") and batch.errors.object is not None: - assert isinstance(batch.errors.object, str), ( - f"errors.object should be str or None, got {type(batch.errors.object)}" - ) - assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}" - - async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data): - """Test successful batch creation and retrieval.""" - created_batch = await provider.create_batch(**sample_batch_data) - - self._validate_batch_type(created_batch, 
expected_metadata=sample_batch_data["metadata"]) - - assert created_batch.id.startswith("batch_") - assert len(created_batch.id) > 13 - assert created_batch.object == "batch" - assert created_batch.endpoint == sample_batch_data["endpoint"] - assert created_batch.input_file_id == sample_batch_data["input_file_id"] - assert created_batch.completion_window == sample_batch_data["completion_window"] - assert created_batch.status == "validating" - assert created_batch.metadata == sample_batch_data["metadata"] - assert isinstance(created_batch.created_at, int) - assert created_batch.created_at > 0 - - retrieved_batch = await provider.retrieve_batch(created_batch.id) - - self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"]) - - assert retrieved_batch.id == created_batch.id - assert retrieved_batch.input_file_id == created_batch.input_file_id - assert retrieved_batch.endpoint == created_batch.endpoint - assert retrieved_batch.status == created_batch.status - assert retrieved_batch.metadata == created_batch.metadata - - async def test_create_batch_without_metadata(self, provider): - """Test batch creation without optional metadata.""" - batch = await provider.create_batch( - input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h" - ) - - assert batch.metadata is None - - async def test_create_batch_completion_window(self, provider): - """Test batch creation with invalid completion window.""" - with pytest.raises(ValueError, match="Invalid completion_window"): - await provider.create_batch( - input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now" - ) - - @pytest.mark.parametrize( - "endpoint", - [ - "/v1/embeddings", - "/v1/completions", - "/v1/invalid/endpoint", - "", - ], - ) - async def test_create_batch_invalid_endpoints(self, provider, endpoint): - """Test batch creation with various invalid endpoints.""" - with pytest.raises(ValueError, match="Invalid endpoint"): - await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h") - - async def test_create_batch_invalid_metadata(self, provider): - """Test that batch creation fails with invalid metadata.""" - with pytest.raises(ValueError, match="should be a valid string"): - await provider.create_batch( - input_file_id="file_123", - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={123: "invalid_key"}, # Non-string key - ) - - with pytest.raises(ValueError, match="should be a valid string"): - await provider.create_batch( - input_file_id="file_123", - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={"valid_key": 456}, # Non-string value - ) - - async def test_retrieve_batch_not_found(self, provider): - """Test error when retrieving non-existent batch.""" - with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): - await provider.retrieve_batch("nonexistent_batch") - - async def test_cancel_batch_success(self, provider, sample_batch_data): - """Test successful batch cancellation.""" - created_batch = await provider.create_batch(**sample_batch_data) - assert created_batch.status == "validating" - - cancelled_batch = await provider.cancel_batch(created_batch.id) - - assert cancelled_batch.id == created_batch.id - assert cancelled_batch.status in ["cancelling", "cancelled"] - assert isinstance(cancelled_batch.cancelling_at, int) - assert cancelled_batch.cancelling_at >= created_batch.created_at - - @pytest.mark.parametrize("status", ["failed", 
"expired", "completed"]) - async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status): - """Test error when cancelling batch in final states.""" - provider.process_batches = False - created_batch = await provider.create_batch(**sample_batch_data) - - # directly update status in kvstore - await provider._update_batch(created_batch.id, status=status) - - with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"): - await provider.cancel_batch(created_batch.id) - - async def test_cancel_batch_not_found(self, provider): - """Test error when cancelling non-existent batch.""" - with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): - await provider.cancel_batch("nonexistent_batch") - - async def test_list_batches_empty(self, provider): - """Test listing batches when none exist.""" - response = await provider.list_batches() - - assert response.object == "list" - assert response.data == [] - assert response.first_id is None - assert response.last_id is None - assert response.has_more is False - - async def test_list_batches_single_batch(self, provider, sample_batch_data): - """Test listing batches with single batch.""" - created_batch = await provider.create_batch(**sample_batch_data) - - response = await provider.list_batches() - - assert len(response.data) == 1 - self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"]) - assert response.data[0].id == created_batch.id - assert response.first_id == created_batch.id - assert response.last_id == created_batch.id - assert response.has_more is False - - async def test_list_batches_multiple_batches(self, provider): - """Test listing multiple batches.""" - batches = [ - await provider.create_batch( - input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" - ) - for i in range(3) - ] - - response = await provider.list_batches() - - assert len(response.data) == 3 - - batch_ids = {batch.id for batch in response.data} - expected_ids = {batch.id for batch in batches} - assert batch_ids == expected_ids - assert response.has_more is False - - assert response.first_id in expected_ids - assert response.last_id in expected_ids - - async def test_list_batches_with_limit(self, provider): - """Test listing batches with limit parameter.""" - batches = [ - await provider.create_batch( - input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" - ) - for i in range(3) - ] - - response = await provider.list_batches(limit=2) - - assert len(response.data) == 2 - assert response.has_more is True - assert response.first_id == response.data[0].id - assert response.last_id == response.data[1].id - batch_ids = {batch.id for batch in response.data} - expected_ids = {batch.id for batch in batches} - assert batch_ids.issubset(expected_ids) - - async def test_list_batches_with_pagination(self, provider): - """Test listing batches with pagination using 'after' parameter.""" - for i in range(3): - await provider.create_batch( - input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" - ) - - # Get first page - first_page = await provider.list_batches(limit=1) - assert len(first_page.data) == 1 - assert first_page.has_more is True - - # Get second page using 'after' - second_page = await provider.list_batches(limit=1, after=first_page.data[0].id) - assert len(second_page.data) == 1 - assert second_page.data[0].id != first_page.data[0].id - - # Verify we got 
the next batch in order - all_batches = await provider.list_batches() - expected_second_batch_id = all_batches.data[1].id - assert second_page.data[0].id == expected_second_batch_id - - async def test_list_batches_invalid_after(self, provider, sample_batch_data): - """Test listing batches with invalid 'after' parameter.""" - await provider.create_batch(**sample_batch_data) - - response = await provider.list_batches(after="nonexistent_batch") - - # Should return all batches (no filtering when 'after' batch not found) - assert len(response.data) == 1 - - async def test_kvstore_persistence(self, provider, sample_batch_data): - """Test that batches are properly persisted in kvstore.""" - batch = await provider.create_batch(**sample_batch_data) - - stored_data = await provider.kvstore.get(f"batch:{batch.id}") - assert stored_data is not None - - stored_batch_dict = json.loads(stored_data) - assert stored_batch_dict["id"] == batch.id - assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"] - - async def test_validate_input_file_not_found(self, provider): - """Test _validate_input when input file does not exist.""" - provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found")) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id="nonexistent_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - assert errors[0].code == "invalid_request" - assert errors[0].message == "Cannot find file nonexistent_file." - assert errors[0].param == "input_file_id" - assert errors[0].line is None - - async def test_validate_input_file_exists_empty_content(self, provider): - """Test _validate_input when file exists but is empty.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - mock_response.body = b"" - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id="empty_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 0 - assert len(requests) == 0 - - async def test_validate_input_file_mixed_valid_invalid_json(self, provider): - """Test _validate_input when file contains valid and invalid JSON lines.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - # Line 1: valid JSON with proper body args, Line 2: invalid JSON - mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json' - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id="mixed_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - # Should have 1 JSON parsing error from line 2, and 1 valid request from line 1 - assert len(errors) == 1 - assert len(requests) == 1 - - assert errors[0].code == "invalid_json_line" - assert errors[0].line == 2 - assert errors[0].message == "This line is 
not parseable as valid JSON." - - assert requests[0].custom_id == "req-1" - assert requests[0].method == "POST" - assert requests[0].url == "/v1/chat/completions" - assert requests[0].body["model"] == "test-model" - assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}] - - async def test_validate_input_invalid_model(self, provider): - """Test _validate_input when file contains request with non-existent model.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}' - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found")) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id="invalid_model_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - - assert errors[0].code == "model_not_found" - assert errors[0].line == 1 - assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported" - assert errors[0].param == "body.model" - - @pytest.mark.parametrize( - "param_name,param_path,error_code,error_message", - [ - ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"), - ("method", "method", "missing_required_parameter", "Missing required parameter: method"), - ("url", "url", "missing_required_parameter", "Missing required parameter: url"), - ("body", "body", "missing_required_parameter", "Missing required parameter: body"), - ("model", "body.model", "invalid_request", "Model parameter is required"), - ("messages", "body.messages", "invalid_request", "Messages parameter is required"), - ], - ) - async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message): - """Test _validate_input when file contains request with missing required parameters.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - - base_request = { - "custom_id": "req-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}, - } - - # Remove the specific parameter being tested - if "." 
in param_path: - top_level, nested_param = param_path.split(".", 1) - del base_request[top_level][nested_param] - else: - del base_request[param_name] - - mock_response.body = json.dumps(base_request).encode() - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id=f"missing_{param_name}_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - - assert errors[0].code == error_code - assert errors[0].line == 1 - assert errors[0].message == error_message - assert errors[0].param == param_path - - async def test_validate_input_url_mismatch(self, provider): - """Test _validate_input when file contains request with URL that doesn't match batch endpoint.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}' - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", # This doesn't match the URL in the request - input_file_id="url_mismatch_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - - assert errors[0].code == "invalid_url" - assert errors[0].line == 1 - assert errors[0].message == "URL provided for this request does not match the batch endpoint" - assert errors[0].param == "url" - - async def test_validate_input_multiple_errors_per_request(self, provider): - """Test _validate_input when a single request has multiple validation errors.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - # Request missing custom_id, has invalid URL, and missing model in body - mock_response.body = ( - b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}' - ) - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request - input_file_id="multiple_errors_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) >= 2 # At least missing custom_id and URL mismatch - assert len(requests) == 0 - - for error in errors: - assert error.line == 1 - - error_codes = {error.code for error in errors} - assert "missing_required_parameter" in error_codes # missing custom_id - assert "invalid_url" in error_codes # URL mismatch - - async def test_validate_input_invalid_request_format(self, provider): - """Test _validate_input when file contains non-object JSON (array, string, number).""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - mock_response.body = b'["not", "a", "request", "object"]' - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - 
endpoint="/v1/chat/completions", - input_file_id="invalid_format_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - - assert errors[0].code == "invalid_request" - assert errors[0].line == 1 - assert errors[0].message == "Each line must be a JSON dictionary object" - - @pytest.mark.parametrize( - "param_name,param_path,invalid_value,error_message", - [ - ("custom_id", "custom_id", 12345, "Custom_id must be a string"), - ("url", "url", 123, "URL must be a string"), - ("method", "method", ["POST"], "Method must be a string"), - ("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"), - ("model", "body.model", 123, "Model must be a string"), - ("messages", "body.messages", "invalid messages format", "Messages must be an array"), - ], - ) - async def test_validate_input_invalid_parameter_types( - self, provider, param_name, param_path, invalid_value, error_message - ): - """Test _validate_input when file contains request with parameters that have invalid types.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - - base_request = { - "custom_id": "req-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}, - } - - # Override the specific parameter with invalid value - if "." in param_path: - top_level, nested_param = param_path.split(".", 1) - base_request[top_level][nested_param] = invalid_value - else: - base_request[param_name] = invalid_value - - mock_response.body = json.dumps(base_request).encode() - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id=f"invalid_{param_name}_type_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - - assert errors[0].code == "invalid_request" - assert errors[0].line == 1 - assert errors[0].message == error_message - assert errors[0].param == param_path - - async def test_max_concurrent_batches(self, provider): - """Test max_concurrent_batches configuration and concurrency control.""" - import asyncio - - provider._batch_semaphore = asyncio.Semaphore(2) - - provider.process_batches = True # enable because we're testing background processing - - active_batches = 0 - - async def add_and_wait(batch_id: str): - nonlocal active_batches - active_batches += 1 - await asyncio.sleep(float("inf")) - - # the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl, - # so we can replace _process_batch_impl with our mock to control concurrency - provider._process_batch_impl = add_and_wait - - for _ in range(3): - await provider.create_batch( - input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h" - ) - - await asyncio.sleep(0.042) # let tasks start - - assert active_batches == 2, f"Expected 2 active batches, got {active_batches}" From c15cc7ed77b7689e9fdf24cbda12a4511db21f89 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Thu, 14 Aug 2025 18:27:00 +0100 Subject: [PATCH 27/45] fix: use ChatCompletionMessageFunctionToolCall (#3142) The OpenAI compatibility layer was incorrectly importing 
ChatCompletionMessageToolCallParam instead of the ChatCompletionMessageFunctionToolCall class. This caused "Cannot instantiate typing.Union" errors when processing agent requests with tool calls. Closes: #3141 Signed-off-by: Derek Higgins --- .../utils/inference/openai_compat.py | 10 ++--- .../utils/inference/test_openai_compat.py | 40 +++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 9a77c8cc4..6297cc2ed 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -31,15 +31,15 @@ from openai.types.chat import ( from openai.types.chat import ( ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, ) +from openai.types.chat import ( + ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall, +) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessage, ) from openai.types.chat import ( ChatCompletionMessageToolCall, ) -from openai.types.chat import ( - ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, -) from openai.types.chat import ( ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, ) @@ -633,7 +633,7 @@ async def convert_message_to_openai_dict_new( ) elif isinstance(message, CompletionMessage): tool_calls = [ - OpenAIChatCompletionMessageToolCall( + OpenAIChatCompletionMessageFunctionToolCall( id=tool.call_id, function=OpenAIFunction( name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), @@ -903,7 +903,7 @@ def _convert_openai_request_response_format( def _convert_openai_tool_calls( - tool_calls: list[OpenAIChatCompletionMessageToolCall], + tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall], ) -> list[ToolCall]: """ Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall. diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py index 5b8527d1b..ddc70e102 100644 --- a/tests/unit/providers/utils/inference/test_openai_compat.py +++ b/tests/unit/providers/utils/inference/test_openai_compat.py @@ -24,6 +24,7 @@ from llama_stack.apis.inference import ( from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, + convert_message_to_openai_dict_new, openai_messages_to_messages, ) @@ -182,3 +183,42 @@ def test_user_message_accepts_images(): assert len(msg.content) == 2 assert msg.content[0].text == "Describe this image:" assert msg.content[1].image_url.url == "http://example.com/image.jpg" + + +async def test_convert_message_to_openai_dict_new_user_message(): + """Test convert_message_to_openai_dict_new with UserMessage.""" + message = UserMessage(content="Hello, world!", role="user") + result = await convert_message_to_openai_dict_new(message) + + assert result["role"] == "user" + assert result["content"] == "Hello, world!" 
+ + +async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls(): + """Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls.""" + message = CompletionMessage( + content="I'll help you find the weather.", + tool_calls=[ + ToolCall( + call_id="call_123", + tool_name="get_weather", + arguments={"city": "Sligo"}, + arguments_json='{"city": "Sligo"}', + ) + ], + stop_reason=StopReason.end_of_turn, + ) + result = await convert_message_to_openai_dict_new(message) + + # This would have failed with "Cannot instantiate typing.Union" before the fix + assert result["role"] == "assistant" + assert result["content"] == "I'll help you find the weather." + assert "tool_calls" in result + assert result["tool_calls"] is not None + assert len(result["tool_calls"]) == 1 + + tool_call = result["tool_calls"][0] + assert tool_call.id == "call_123" + assert tool_call.type == "function" + assert tool_call.function.name == "get_weather" + assert tool_call.function.arguments == '{"city": "Sligo"}' From 61582f327cc44dad9e79b1f06e8fba516832908d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 14 Aug 2025 10:27:25 -0700 Subject: [PATCH 28/45] fix(ci): update triggers for the workflows (#3152) --- .github/workflows/integration-tests.yml | 4 ++-- .github/workflows/record-integration-tests.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index f330d2c45..9ef49fba3 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -5,7 +5,7 @@ run-name: Run the integration test suite from tests/integration in replay mode on: push: branches: [ main ] - pull_request_target: + pull_request: branches: [ main ] types: [opened, synchronize, reopened] paths: @@ -34,7 +34,7 @@ on: concurrency: # Skip concurrency for pushes to main - each commit should be tested independently - group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml index 12957db27..b31709a4f 100644 --- a/.github/workflows/record-integration-tests.yml +++ b/.github/workflows/record-integration-tests.yml @@ -3,7 +3,7 @@ name: Integration Tests (Record) run-name: Run the integration test suite from tests/integration on: - pull_request: + pull_request_target: branches: [ main ] types: [opened, synchronize, labeled] paths: @@ -23,7 +23,7 @@ on: default: 'ollama' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} cancel-in-progress: true jobs: From e69acbafbfd902333ede09361c214fc82a8e895a Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Thu, 14 Aug 2025 15:58:43 -0600 Subject: [PATCH 29/45] feat(UI): Adding linter and prettier for UI (#3156) --- .pre-commit-config.yaml | 15 ++ llama_stack/ui/.nvmrc | 1 + llama_stack/ui/.prettierignore | 9 + llama_stack/ui/.prettierrc | 11 +- llama_stack/ui/app/api/v1/[...path]/route.ts | 4 +- llama_stack/ui/app/auth/signin/page.tsx | 4 +- llama_stack/ui/app/chat-playground/page.tsx | 192 ++++++++------ .../app/logs/chat-completions/[id]/page.tsx | 4 +- .../ui/app/logs/responses/[id]/page.tsx | 16 +- .../[fileId]/contents/[contentId]/page.tsx 
| 133 ++++++---- .../[id]/files/[fileId]/contents/page.tsx | 88 +++++-- .../[id]/files/[fileId]/page.tsx | 90 +++++-- .../ui/app/logs/vector-stores/[id]/page.tsx | 15 +- .../ui/app/logs/vector-stores/page.tsx | 137 +++++----- .../chat-completion-detail.test.tsx | 28 +-- .../chat-completion-detail.tsx | 29 ++- .../chat-completion-table.test.tsx | 34 +-- .../chat-completions-table.tsx | 5 +- .../chat-completions/chat-messasge-item.tsx | 35 +-- .../chat-playground/chat-message.tsx | 156 ++++++------ .../ui/components/chat-playground/chat.tsx | 211 ++++++++-------- .../chat-playground/interrupt-prompt.tsx | 12 +- .../chat-playground/markdown-renderer.tsx | 120 ++++----- .../chat-playground/message-input.tsx | 237 +++++++++--------- .../chat-playground/message-list.tsx | 20 +- .../chat-playground/prompt-suggestions.tsx | 10 +- .../chat-playground/typing-indicator.tsx | 4 +- .../ui/components/layout/app-sidebar.tsx | 103 ++++---- .../ui/components/layout/detail-layout.tsx | 2 +- .../logs/logs-table-scroll.test.tsx | 22 +- .../ui/components/logs/logs-table.test.tsx | 42 ++-- llama_stack/ui/components/logs/logs-table.tsx | 2 +- .../grouping/grouped-items-display.tsx | 2 +- .../responses/hooks/function-call-grouping.ts | 2 +- .../responses/items/item-renderer.tsx | 2 +- .../responses/items/message-item.tsx | 2 +- .../responses/responses-detail.test.tsx | 60 ++--- .../responses/responses-table.test.tsx | 34 +-- .../components/responses/responses-table.tsx | 20 +- .../components/responses/utils/item-types.ts | 10 +- .../ui/components/ui/audio-visualizer.tsx | 146 +++++------ llama_stack/ui/components/ui/breadcrumb.tsx | 2 +- llama_stack/ui/components/ui/button.tsx | 18 +- llama_stack/ui/components/ui/card.tsx | 6 +- llama_stack/ui/components/ui/collapsible.tsx | 12 +- llama_stack/ui/components/ui/copy-button.tsx | 20 +- .../ui/components/ui/dropdown-menu.tsx | 16 +- llama_stack/ui/components/ui/file-preview.tsx | 56 ++--- llama_stack/ui/components/ui/input.tsx | 2 +- llama_stack/ui/components/ui/select.tsx | 34 +-- llama_stack/ui/components/ui/separator.tsx | 2 +- llama_stack/ui/components/ui/sheet.tsx | 4 +- llama_stack/ui/components/ui/sidebar.tsx | 36 +-- llama_stack/ui/components/ui/sonner.tsx | 14 +- llama_stack/ui/components/ui/table.tsx | 8 +- llama_stack/ui/components/ui/tooltip.tsx | 2 +- .../vector-stores/vector-store-detail.tsx | 4 +- llama_stack/ui/e2e/logs-table-scroll.spec.ts | 2 +- llama_stack/ui/eslint.config.mjs | 8 +- llama_stack/ui/hooks/use-audio-recording.ts | 82 +++--- llama_stack/ui/hooks/use-auto-scroll.ts | 50 ++-- llama_stack/ui/hooks/use-autosize-textarea.ts | 32 +-- llama_stack/ui/hooks/use-copy-to-clipboard.ts | 34 +-- llama_stack/ui/hooks/use-infinite-scroll.ts | 6 +- llama_stack/ui/hooks/use-mobile.ts | 2 +- llama_stack/ui/hooks/use-pagination.ts | 10 +- llama_stack/ui/lib/audio-utils.ts | 54 ++-- llama_stack/ui/lib/config-validator.ts | 8 +- llama_stack/ui/lib/contents-api.ts | 45 ++-- .../ui/lib/format-message-content.test.ts | 28 ++- llama_stack/ui/lib/format-message-content.ts | 4 +- llama_stack/ui/lib/format-tool-call.tsx | 6 +- llama_stack/ui/lib/truncate-text.ts | 2 +- 73 files changed, 1452 insertions(+), 1226 deletions(-) create mode 100644 llama_stack/ui/.nvmrc diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 30843173c..4309f289a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,6 +2,7 @@ exclude: 'build/' default_language_version: python: python3.12 + node: "22" repos: - repo: 
https://github.com/pre-commit/pre-commit-hooks @@ -145,6 +146,20 @@ repos: pass_filenames: false require_serial: true files: ^.github/workflows/.*$ + - id: ui-prettier + name: Format UI code with Prettier + entry: bash -c 'cd llama_stack/ui && npm run format' + language: system + files: ^llama_stack/ui/.*\.(ts|tsx)$ + pass_filenames: false + require_serial: true + - id: ui-eslint + name: Lint UI code with ESLint + entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet' + language: system + files: ^llama_stack/ui/.*\.(ts|tsx)$ + pass_filenames: false + require_serial: true ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks diff --git a/llama_stack/ui/.nvmrc b/llama_stack/ui/.nvmrc new file mode 100644 index 000000000..1384ff6a1 --- /dev/null +++ b/llama_stack/ui/.nvmrc @@ -0,0 +1 @@ +22.5.1 diff --git a/llama_stack/ui/.prettierignore b/llama_stack/ui/.prettierignore index 1b8ac8894..b737ae6ed 100644 --- a/llama_stack/ui/.prettierignore +++ b/llama_stack/ui/.prettierignore @@ -1,3 +1,12 @@ # Ignore artifacts: build coverage +.next +node_modules +dist +*.lock +*.log + +# Generated files +*.min.js +*.min.css diff --git a/llama_stack/ui/.prettierrc b/llama_stack/ui/.prettierrc index 0967ef424..059475a24 100644 --- a/llama_stack/ui/.prettierrc +++ b/llama_stack/ui/.prettierrc @@ -1 +1,10 @@ -{} +{ + "semi": true, + "trailingComma": "es5", + "singleQuote": false, + "printWidth": 80, + "tabWidth": 2, + "useTabs": false, + "bracketSpacing": true, + "arrowParens": "avoid" +} diff --git a/llama_stack/ui/app/api/v1/[...path]/route.ts b/llama_stack/ui/app/api/v1/[...path]/route.ts index 1959f9099..51c1f8004 100644 --- a/llama_stack/ui/app/api/v1/[...path]/route.ts +++ b/llama_stack/ui/app/api/v1/[...path]/route.ts @@ -47,7 +47,7 @@ async function proxyRequest(request: NextRequest, method: string) { const responseText = await response.text(); console.log( - `Response from FastAPI: ${response.status} ${response.statusText}`, + `Response from FastAPI: ${response.status} ${response.statusText}` ); // Create response with same status and headers @@ -74,7 +74,7 @@ async function proxyRequest(request: NextRequest, method: string) { backend_url: BACKEND_URL, timestamp: new Date().toISOString(), }, - { status: 500 }, + { status: 500 } ); } } diff --git a/llama_stack/ui/app/auth/signin/page.tsx b/llama_stack/ui/app/auth/signin/page.tsx index c9510fd6b..0ccb4a397 100644 --- a/llama_stack/ui/app/auth/signin/page.tsx +++ b/llama_stack/ui/app/auth/signin/page.tsx @@ -51,9 +51,9 @@ export default function SignInPage() { onClick={() => { console.log("Signing in with GitHub..."); signIn("github", { callbackUrl: "/auth/signin" }).catch( - (error) => { + error => { console.error("Sign in error:", error); - }, + } ); }} className="w-full" diff --git a/llama_stack/ui/app/chat-playground/page.tsx b/llama_stack/ui/app/chat-playground/page.tsx index d8094af85..b8651aca0 100644 --- a/llama_stack/ui/app/chat-playground/page.tsx +++ b/llama_stack/ui/app/chat-playground/page.tsx @@ -29,14 +29,13 @@ export default function ChatPlaygroundPage() { const isModelsLoading = modelsLoading ?? 
true; - useEffect(() => { const fetchModels = async () => { try { setModelsLoading(true); setModelsError(null); const modelList = await client.models.list(); - const llmModels = modelList.filter(model => model.model_type === 'llm'); + const llmModels = modelList.filter(model => model.model_type === "llm"); setModels(llmModels); if (llmModels.length > 0) { setSelectedModel(llmModels[0].identifier); @@ -53,103 +52,122 @@ export default function ChatPlaygroundPage() { }, [client]); const extractTextContent = (content: unknown): string => { - if (typeof content === 'string') { + if (typeof content === "string") { return content; } if (Array.isArray(content)) { return content - .filter(item => item && typeof item === 'object' && 'type' in item && item.type === 'text') - .map(item => (item && typeof item === 'object' && 'text' in item) ? String(item.text) : '') - .join(''); + .filter( + item => + item && + typeof item === "object" && + "type" in item && + item.type === "text" + ) + .map(item => + item && typeof item === "object" && "text" in item + ? String(item.text) + : "" + ) + .join(""); } - if (content && typeof content === 'object' && 'type' in content && content.type === 'text' && 'text' in content) { - return String(content.text) || ''; + if ( + content && + typeof content === "object" && + "type" in content && + content.type === "text" && + "text" in content + ) { + return String(content.text) || ""; } - return ''; + return ""; }; const handleInputChange = (e: React.ChangeEvent) => { setInput(e.target.value); }; -const handleSubmit = async (event?: { preventDefault?: () => void }) => { - event?.preventDefault?.(); - if (!input.trim()) return; + const handleSubmit = async (event?: { preventDefault?: () => void }) => { + event?.preventDefault?.(); + if (!input.trim()) return; - // Add user message to chat - const userMessage: Message = { - id: Date.now().toString(), - role: "user", - content: input.trim(), - createdAt: new Date(), - }; - - setMessages(prev => [...prev, userMessage]); - setInput(""); - - // Use the helper function with the content - await handleSubmitWithContent(userMessage.content); -}; - -const handleSubmitWithContent = async (content: string) => { - setIsGenerating(true); - setError(null); - - try { - const messageParams: CompletionCreateParams["messages"] = [ - ...messages.map(msg => { - const msgContent = typeof msg.content === 'string' ? 
msg.content : extractTextContent(msg.content); - if (msg.role === "user") { - return { role: "user" as const, content: msgContent }; - } else if (msg.role === "assistant") { - return { role: "assistant" as const, content: msgContent }; - } else { - return { role: "system" as const, content: msgContent }; - } - }), - { role: "user" as const, content } - ]; - - const response = await client.chat.completions.create({ - model: selectedModel, - messages: messageParams, - stream: true, - }); - - const assistantMessage: Message = { - id: (Date.now() + 1).toString(), - role: "assistant", - content: "", + // Add user message to chat + const userMessage: Message = { + id: Date.now().toString(), + role: "user", + content: input.trim(), createdAt: new Date(), }; - setMessages(prev => [...prev, assistantMessage]); - let fullContent = ""; - for await (const chunk of response) { - if (chunk.choices && chunk.choices[0]?.delta?.content) { - const deltaContent = chunk.choices[0].delta.content; - fullContent += deltaContent; + setMessages(prev => [...prev, userMessage]); + setInput(""); - flushSync(() => { - setMessages(prev => { - const newMessages = [...prev]; - const lastMessage = newMessages[newMessages.length - 1]; - if (lastMessage.role === "assistant") { - lastMessage.content = fullContent; - } - return newMessages; + // Use the helper function with the content + await handleSubmitWithContent(userMessage.content); + }; + + const handleSubmitWithContent = async (content: string) => { + setIsGenerating(true); + setError(null); + + try { + const messageParams: CompletionCreateParams["messages"] = [ + ...messages.map(msg => { + const msgContent = + typeof msg.content === "string" + ? msg.content + : extractTextContent(msg.content); + if (msg.role === "user") { + return { role: "user" as const, content: msgContent }; + } else if (msg.role === "assistant") { + return { role: "assistant" as const, content: msgContent }; + } else { + return { role: "system" as const, content: msgContent }; + } + }), + { role: "user" as const, content }, + ]; + + const response = await client.chat.completions.create({ + model: selectedModel, + messages: messageParams, + stream: true, + }); + + const assistantMessage: Message = { + id: (Date.now() + 1).toString(), + role: "assistant", + content: "", + createdAt: new Date(), + }; + + setMessages(prev => [...prev, assistantMessage]); + let fullContent = ""; + for await (const chunk of response) { + if (chunk.choices && chunk.choices[0]?.delta?.content) { + const deltaContent = chunk.choices[0].delta.content; + fullContent += deltaContent; + + flushSync(() => { + setMessages(prev => { + const newMessages = [...prev]; + const lastMessage = newMessages[newMessages.length - 1]; + if (lastMessage.role === "assistant") { + lastMessage.content = fullContent; + } + return newMessages; + }); }); - }); + } } + } catch (err) { + console.error("Error sending message:", err); + setError("Failed to send message. Please try again."); + setMessages(prev => prev.slice(0, -1)); + } finally { + setIsGenerating(false); } - } catch (err) { - console.error("Error sending message:", err); - setError("Failed to send message. 
Please try again."); - setMessages(prev => prev.slice(0, -1)); - } finally { - setIsGenerating(false); - } -}; + }; const suggestions = [ "Write a Python function that prints 'Hello, World!'", "Explain step-by-step how to solve this math problem: If x² + 6x + 9 = 25, what is x?", @@ -163,7 +181,7 @@ const handleSubmitWithContent = async (content: string) => { content: message.content, createdAt: new Date(), }; - setMessages(prev => [...prev, newMessage]) + setMessages(prev => [...prev, newMessage]); handleSubmitWithContent(newMessage.content); }; @@ -177,12 +195,20 @@ const handleSubmitWithContent = async (content: string) => {

Chat Playground (Completions)

- - + - {models.map((model) => ( + {models.map(model => ( {model.identifier} diff --git a/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx b/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx index 82aa3496e..e11924f4c 100644 --- a/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx +++ b/llama_stack/ui/app/logs/chat-completions/[id]/page.tsx @@ -33,12 +33,12 @@ export default function ChatCompletionDetailPage() { } catch (err) { console.error( `Error fetching chat completion detail for ID ${id}:`, - err, + err ); setError( err instanceof Error ? err - : new Error("Failed to fetch completion detail"), + : new Error("Failed to fetch completion detail") ); } finally { setIsLoading(false); diff --git a/llama_stack/ui/app/logs/responses/[id]/page.tsx b/llama_stack/ui/app/logs/responses/[id]/page.tsx index 7f4252856..922d35531 100644 --- a/llama_stack/ui/app/logs/responses/[id]/page.tsx +++ b/llama_stack/ui/app/logs/responses/[id]/page.tsx @@ -13,10 +13,10 @@ export default function ResponseDetailPage() { const client = useAuthClient(); const [responseDetail, setResponseDetail] = useState( - null, + null ); const [inputItems, setInputItems] = useState( - null, + null ); const [isLoading, setIsLoading] = useState(true); const [isLoadingInputItems, setIsLoadingInputItems] = useState(true); @@ -25,7 +25,7 @@ export default function ResponseDetailPage() { // Helper function to convert ResponseObject to OpenAIResponse const convertResponseObject = ( - responseData: ResponseObject, + responseData: ResponseObject ): OpenAIResponse => { return { id: responseData.id, @@ -73,12 +73,12 @@ export default function ResponseDetailPage() { } else { console.error( `Error fetching response detail for ID ${id}:`, - responseResult.reason, + responseResult.reason ); setError( responseResult.reason instanceof Error ? responseResult.reason - : new Error("Failed to fetch response detail"), + : new Error("Failed to fetch response detail") ); } @@ -90,18 +90,18 @@ export default function ResponseDetailPage() { } else { console.error( `Error fetching input items for response ID ${id}:`, - inputItemsResult.reason, + inputItemsResult.reason ); setInputItemsError( inputItemsResult.reason instanceof Error ? inputItemsResult.reason - : new Error("Failed to fetch input items"), + : new Error("Failed to fetch input items") ); } } catch (err) { console.error(`Unexpected error fetching data for ID ${id}:`, err); setError( - err instanceof Error ? err : new Error("Unexpected error occurred"), + err instanceof Error ? 
err : new Error("Unexpected error occurred") ); } finally { setIsLoading(false); diff --git a/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx b/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx index 6896b992a..d58de3085 100644 --- a/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx +++ b/llama_stack/ui/app/logs/vector-stores/[id]/files/[fileId]/contents/[contentId]/page.tsx @@ -18,7 +18,10 @@ import { PropertiesCard, PropertyItem, } from "@/components/layout/detail-layout"; -import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb"; +import { + PageBreadcrumb, + BreadcrumbSegment, +} from "@/components/layout/page-breadcrumb"; export default function ContentDetailPage() { const params = useParams(); @@ -28,13 +31,13 @@ export default function ContentDetailPage() { const contentId = params.contentId as string; const client = useAuthClient(); - const getTextFromContent = (content: any): string => { - if (typeof content === 'string') { + const getTextFromContent = (content: unknown): string => { + if (typeof content === "string") { return content; - } else if (content && content.type === 'text') { + } else if (content && content.type === "text") { return content.text; } - return ''; + return ""; }; const [store, setStore] = useState(null); @@ -44,7 +47,9 @@ export default function ContentDetailPage() { const [error, setError] = useState(null); const [isEditing, setIsEditing] = useState(false); const [editedContent, setEditedContent] = useState(""); - const [editedMetadata, setEditedMetadata] = useState>({}); + const [editedMetadata, setEditedMetadata] = useState>( + {} + ); const [isEditingEmbedding, setIsEditingEmbedding] = useState(false); const [editedEmbedding, setEditedEmbedding] = useState([]); @@ -64,8 +69,13 @@ export default function ContentDetailPage() { setFile(fileResponse as VectorStoreFile); const contentsAPI = new ContentsAPI(client); - const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId); - const targetContent = contentsResponse.data.find(c => c.id === contentId); + const contentsResponse = await contentsAPI.listContents( + vectorStoreId, + fileId + ); + const targetContent = contentsResponse.data.find( + c => c.id === contentId + ); if (targetContent) { setContent(targetContent); @@ -76,7 +86,9 @@ export default function ContentDetailPage() { throw new Error(`Content ${contentId} not found`); } } catch (err) { - setError(err instanceof Error ? err : new Error("Failed to load content.")); + setError( + err instanceof Error ? 
err : new Error("Failed to load content.") + ); } finally { setIsLoading(false); } @@ -88,7 +100,8 @@ export default function ContentDetailPage() { if (!content) return; try { - const updates: { content?: string; metadata?: Record } = {}; + const updates: { content?: string; metadata?: Record } = + {}; if (editedContent !== getTextFromContent(content.content)) { updates.content = editedContent; @@ -100,25 +113,32 @@ export default function ContentDetailPage() { if (Object.keys(updates).length > 0) { const contentsAPI = new ContentsAPI(client); - const updatedContent = await contentsAPI.updateContent(vectorStoreId, fileId, contentId, updates); + const updatedContent = await contentsAPI.updateContent( + vectorStoreId, + fileId, + contentId, + updates + ); setContent(updatedContent); } setIsEditing(false); } catch (err) { - console.error('Failed to update content:', err); + console.error("Failed to update content:", err); } }; const handleDelete = async () => { - if (!confirm('Are you sure you want to delete this content?')) return; + if (!confirm("Are you sure you want to delete this content?")) return; try { const contentsAPI = new ContentsAPI(client); await contentsAPI.deleteContent(vectorStoreId, fileId, contentId); - router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`); + router.push( + `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents` + ); } catch (err) { - console.error('Failed to delete content:', err); + console.error("Failed to delete content:", err); } }; @@ -134,10 +154,19 @@ export default function ContentDetailPage() { const breadcrumbSegments: BreadcrumbSegment[] = [ { label: "Vector Stores", href: "/logs/vector-stores" }, - { label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` }, + { + label: store?.name || vectorStoreId, + href: `/logs/vector-stores/${vectorStoreId}`, + }, { label: "Files", href: `/logs/vector-stores/${vectorStoreId}` }, - { label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` }, - { label: "Contents", href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents` }, + { + label: fileId, + href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}`, + }, + { + label: "Contents", + href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`, + }, { label: contentId }, ]; @@ -186,7 +215,7 @@ export default function ContentDetailPage() { {isEditing ? (