From 7cd1c2c238969cb41ddb3d87b7dcfe5331350be1 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Sat, 6 Sep 2025 07:26:34 -0600 Subject: [PATCH 1/5] feat: Updating Rag Tool to use Files API and Vector Stores API (#3344) --- docs/source/getting_started/demo_script.py | 5 +- .../inline/tool_runtime/rag/__init__.py | 2 +- .../inline/tool_runtime/rag/memory.py | 74 ++++++++++++++----- .../providers/registry/tool_runtime.py | 2 +- .../integration/tool_runtime/test_rag_tool.py | 41 ++++++---- tests/unit/rag/test_rag_query.py | 8 +- 6 files changed, 93 insertions(+), 39 deletions(-) diff --git a/docs/source/getting_started/demo_script.py b/docs/source/getting_started/demo_script.py index 777fc78c2..2ea67739f 100644 --- a/docs/source/getting_started/demo_script.py +++ b/docs/source/getting_started/demo_script.py @@ -18,12 +18,13 @@ embedding_model_id = ( ).identifier embedding_dimension = em.metadata["embedding_dimension"] -_ = client.vector_dbs.register( +vector_db = client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, provider_id="faiss", ) +vector_db_id = vector_db.identifier source = "https://www.paulgraham.com/greatwork.html" print("rag_tool> Ingesting document:", source) document = RAGDocument( @@ -35,7 +36,7 @@ document = RAGDocument( client.tool_runtime.rag_tool.insert( documents=[document], vector_db_id=vector_db_id, - chunk_size_in_tokens=50, + chunk_size_in_tokens=100, ) agent = Agent( client, diff --git a/llama_stack/providers/inline/tool_runtime/rag/__init__.py b/llama_stack/providers/inline/tool_runtime/rag/__init__.py index f9a6e5c55..f9a7e7b89 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/rag/__init__.py @@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]): from .memory import MemoryToolRuntimeImpl - impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference]) + impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index a1543457b..cb526e8ee 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -5,10 +5,15 @@ # the root directory of this source tree. 
import asyncio +import base64 +import io +import mimetypes import secrets import string from typing import Any +import httpx +from fastapi import UploadFile from pydantic import TypeAdapter from llama_stack.apis.common.content_types import ( @@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import ( InterleavedContentItem, TextContentItem, ) +from llama_stack.apis.files import Files, OpenAIFilePurpose from llama_stack.apis.inference import Inference from llama_stack.apis.tools import ( ListToolDefsResponse, @@ -30,13 +36,18 @@ from llama_stack.apis.tools import ( ToolParameter, ToolRuntime, ) -from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO +from llama_stack.apis.vector_io import ( + QueryChunksResponse, + VectorIO, + VectorStoreChunkingStrategyStatic, + VectorStoreChunkingStrategyStaticConfig, +) from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( content_from_doc, - make_overlapped_chunks, + parse_data_url, ) from .config import RagToolRuntimeConfig @@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti config: RagToolRuntimeConfig, vector_io_api: VectorIO, inference_api: Inference, + files_api: Files, ): self.config = config self.vector_io_api = vector_io_api self.inference_api = inference_api + self.files_api = files_api async def initialize(self): pass @@ -78,27 +91,50 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: - chunks = [] + if not documents: + return + for doc in documents: - content = await content_from_doc(doc) - # TODO: we should add enrichment here as URLs won't be added to the metadata by default - chunks.extend( - make_overlapped_chunks( - doc.document_id, - content, - chunk_size_in_tokens, - chunk_size_in_tokens // 4, - doc.metadata, + if isinstance(doc.content, URL): + if doc.content.uri.startswith("data:"): + parts = parse_data_url(doc.content.uri) + file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode() + mime_type = parts["mimetype"] + else: + async with httpx.AsyncClient() as client: + response = await client.get(doc.content.uri) + file_data = response.content + mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream") + else: + content_str = await content_from_doc(doc) + file_data = content_str.encode("utf-8") + mime_type = doc.mime_type or "text/plain" + + file_extension = mimetypes.guess_extension(mime_type) or ".txt" + filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}") + + file_obj = io.BytesIO(file_data) + file_obj.name = filename + + upload_file = UploadFile(file=file_obj, filename=filename) + + created_file = await self.files_api.openai_upload_file( + file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS + ) + + chunking_strategy = VectorStoreChunkingStrategyStatic( + static=VectorStoreChunkingStrategyStaticConfig( + max_chunk_size_tokens=chunk_size_in_tokens, + chunk_overlap_tokens=chunk_size_in_tokens // 4, ) ) - if not chunks: - return - - await self.vector_io_api.insert_chunks( - chunks=chunks, - vector_db_id=vector_db_id, - ) + await self.vector_io_api.openai_attach_file_to_vector_store( + vector_store_id=vector_db_id, + file_id=created_file.id, + attributes=doc.metadata, 
+ chunking_strategy=chunking_strategy, + ) async def query( self, diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 661851443..5a58fa7af 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -32,7 +32,7 @@ def available_providers() -> list[ProviderSpec]: ], module="llama_stack.providers.inline.tool_runtime.rag", config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig", - api_dependencies=[Api.vector_io, Api.inference], + api_dependencies=[Api.vector_io, Api.inference, Api.files], description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.", ), remote_provider_spec( diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py index 2affe2a2d..b208500d8 100644 --- a/tests/integration/tool_runtime/test_rag_tool.py +++ b/tests/integration/tool_runtime/test_rag_tool.py @@ -17,10 +17,14 @@ def client_with_empty_registry(client_with_models): client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id) clear_registry() + + try: + client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime") + except Exception: + pass + yield client_with_models - # you must clean after the last test if you were running tests against - # a stateful server instance clear_registry() @@ -66,12 +70,13 @@ def assert_valid_text_response(response): def test_vector_db_insert_inline_and_query( client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension ): - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + vector_db = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + vector_db_id = vector_db.identifier client_with_empty_registry.tool_runtime.rag_tool.insert( documents=sample_documents, @@ -134,7 +139,11 @@ def test_vector_db_insert_from_url_and_query( # list to check memory bank is successfully registered available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - assert vector_db_id in available_vector_dbs + # VectorDB is being migrated to VectorStore, so the ID will be different + # Just check that at least one vector DB was registered + assert len(available_vector_dbs) > 0 + # Use the actual registered vector_db_id for subsequent operations + actual_vector_db_id = available_vector_dbs[0] urls = [ "memory_optimizations.rst", @@ -153,13 +162,13 @@ def test_vector_db_insert_from_url_and_query( client_with_empty_registry.tool_runtime.rag_tool.insert( documents=documents, - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunk_size_in_tokens=512, ) # Query for the name of method response1 = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="What's the name of the fine-tunning method used?", ) assert_valid_chunk_response(response1) @@ -167,7 +176,7 @@ def test_vector_db_insert_from_url_and_query( # Query for the name of model response2 = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="Which Llama model is mentioned?", ) assert_valid_chunk_response(response2) @@ -187,7 +196,11 @@ def 
test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i ) available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - assert vector_db_id in available_vector_dbs + # VectorDB is being migrated to VectorStore, so the ID will be different + # Just check that at least one vector DB was registered + assert len(available_vector_dbs) > 0 + # Use the actual registered vector_db_id for subsequent operations + actual_vector_db_id = available_vector_dbs[0] urls = [ "memory_optimizations.rst", @@ -206,19 +219,19 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i client_with_empty_registry.tool_runtime.rag_tool.insert( documents=documents, - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunk_size_in_tokens=512, ) response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", ) assert_valid_text_response(response_with_metadata) assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content) response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", query_config={ "include_metadata_in_content": True, @@ -230,7 +243,7 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i with pytest.raises((ValueError, BadRequestError)): client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", query_config={ "chunk_template": "This should raise a ValueError because it is missing the proper template variables", diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index 05ccecb99..d18d90716 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -19,12 +19,16 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti class TestRagQuery: async def test_query_raises_on_empty_vector_db_ids(self): - rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() + ) with pytest.raises(ValueError): await rag_tool.query(content=MagicMock(), vector_db_ids=[]) async def test_query_chunk_metadata_handling(self): - rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() + ) content = "test query content" vector_db_ids = ["db1"] From d6c3b363904eedd9c5323594603dc4f2d5b581eb Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 6 Sep 2025 15:22:20 -0400 Subject: [PATCH 2/5] chore: update the gemini inference impl to use openai-python for openai-compat functions (#3351) # What does this PR do? 
update the Gemini inference provider to use openai-python for the openai-compat endpoints partially addresses #3349, does not address /inference/completion or /inference/chat-completion ## Test Plan ci --- llama_stack/providers/registry/inference.py | 2 +- llama_stack/providers/remote/inference/gemini/gemini.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 50956f58c..1bb3c3147 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -207,7 +207,7 @@ def available_providers() -> list[ProviderSpec]: api=Api.inference, adapter=AdapterSpec( adapter_type="gemini", - pip_packages=["litellm"], + pip_packages=["litellm", "openai"], module="llama_stack.providers.remote.inference.gemini", config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", diff --git a/llama_stack/providers/remote/inference/gemini/gemini.py b/llama_stack/providers/remote/inference/gemini/gemini.py index b6048eff7..569227fdd 100644 --- a/llama_stack/providers/remote/inference/gemini/gemini.py +++ b/llama_stack/providers/remote/inference/gemini/gemini.py @@ -5,12 +5,13 @@ # the root directory of this source tree. from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import GeminiConfig from .models import MODEL_ENTRIES -class GeminiInferenceAdapter(LiteLLMOpenAIMixin): +class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): def __init__(self, config: GeminiConfig) -> None: LiteLLMOpenAIMixin.__init__( self, @@ -21,6 +22,11 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin): ) self.config = config + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self): + return "https://generativelanguage.googleapis.com/v1beta/openai/" + async def initialize(self) -> None: await super().initialize() From 4c28544c04ab96c5c4d188dda3a53dc2eab0415d Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 6 Sep 2025 15:22:44 -0400 Subject: [PATCH 3/5] chore(gemini, tests): add skips for n and completions, gemini api does not support them (#3350) # What does this PR do? 
the gemini api endpoints do not support the n param or completions ## Test Plan ci --- tests/integration/inference/test_openai_completion.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index bb447b3c1..099032578 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -37,6 +37,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) "remote::sambanova", "remote::tgi", "remote::vertexai", + "remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404 ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") @@ -63,6 +64,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id): if provider.provider_type in ( "remote::sambanova", "remote::ollama", + # Error code: 400 - [{'error': {'code': 400, 'message': 'Only one candidate can be specified in the + # current model', 'status': 'INVALID_ARGUMENT'}}] + "remote::gemini", ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.") From bf02cd846fdc39db80291746e06ca547e5afdbdb Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 6 Sep 2025 15:25:13 -0400 Subject: [PATCH 4/5] chore: update the sambanova inference impl to use openai-python for openai-compat functions (#3345) # What does this PR do? update SambaNova inference provider to use OpenAIMixin for openai-compat endpoints ## Test Plan ``` $ SAMBANOVA_API_KEY=... uv run llama stack build --image-type venv --providers inference=remote::sambanova --run ... $ LLAMA_STACK_CONFIG=http://localhost:8321 uv run --group test pytest -v -ra --text-model sambanova/Meta-Llama-3.3-70B-Instruct tests/integration/inference -k 'not store' ... FAILED tests/integration/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[txt=sambanova/Meta-Llama-3.3-70B-Instruct-inference:chat_completion:tool_calling_tools_absent-True] - AttributeError: 'NoneType' object has no attribute 'delta' FAILED tests/integration/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[txt=sambanova/Meta-Llama-3.3-70B-Instruct-inference:chat_completion:tool_calling_tools_absent-False] - llama_stack_client.InternalServerError: Error code: 500 - {'detail': 'Internal server error: An une... =========== 2 failed, 16 passed, 68 skipped, 8 deselected, 3 xfailed, 13 warnings in 15.85s ============ ``` the two failures also exist before this change. they are part of the deprecated inference.chat_completion tests that flow through litellm. they can be resolved later. 
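For reviewers, a minimal sketch of why the `OpenAIMixin, LiteLLMOpenAIMixin` base-class ordering matters: Python's MRO resolves `check_model_availability()` to the leftmost base that defines it. The classes below are simplified stand-ins for illustration, not the real llama_stack mixins.

```
# Stand-in classes only; not the real llama_stack mixins.
class LiteLLMOpenAIMixin:
    async def check_model_availability(self, model: str) -> bool:
        # stand-in for checking LiteLLM's static model registry
        return False


class OpenAIMixin:
    async def check_model_availability(self, model: str) -> bool:
        # stand-in for querying the provider's live /v1/models endpoint
        return True


class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
    pass


# OpenAIMixin is listed first, so its method wins in the MRO:
# [SambaNovaInferenceAdapter, OpenAIMixin, LiteLLMOpenAIMixin, object]
print([cls.__name__ for cls in SambaNovaInferenceAdapter.__mro__])
```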
--- llama_stack/providers/registry/inference.py | 2 +- .../remote/inference/sambanova/sambanova.py | 26 ++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 1bb3c3147..7a95fd089 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -270,7 +270,7 @@ Available Models: api=Api.inference, adapter=AdapterSpec( adapter_type="sambanova", - pip_packages=["litellm"], + pip_packages=["litellm", "openai"], module="llama_stack.providers.remote.inference.sambanova", config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 96469acac..ee3b0f648 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -4,13 +4,26 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import SambaNovaImplConfig from .models import MODEL_ENTRIES -class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): +class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): + """ + SambaNova Inference Adapter for Llama Stack. + + Note: The inheritance order is important here. OpenAIMixin must come before + LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability() + is used instead of LiteLLMOpenAIMixin.check_model_availability(). + + - OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists + - LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM + """ + def __init__(self, config: SambaNovaImplConfig): self.config = config self.environment_available_models = [] @@ -24,3 +37,14 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): download_images=True, # SambaNova requires base64 image encoding json_schema_strict=False, # SambaNova doesn't support strict=True yet ) + + # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self) -> str: + """ + Get the base URL for OpenAI mixin. + + :return: The SambaNova base URL + """ + return self.config.url From 9252d9fc018487ef7bd6ac400ed329cd5db1e8c4 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 6 Sep 2025 15:35:30 -0400 Subject: [PATCH 5/5] chore(groq test): skip with_n tests for groq, it is not supported server-side (#3346) # What does this PR do? 
skip the with_n test for groq, because it isn't supported by the provider's service see https://console.groq.com/docs/openai#currently-unsupported-openai-features Co-authored-by: raghotham --- tests/integration/inference/test_openai_completion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 099032578..2043f6aeb 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -64,6 +64,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id): if provider.provider_type in ( "remote::sambanova", "remote::ollama", + # https://console.groq.com/docs/openai#currently-unsupported-openai-features + # -> Error code: 400 - {'error': {'message': "'n' : number must be at most 1", 'type': 'invalid_request_error'}} + "remote::groq", # Error code: 400 - [{'error': {'code': 400, 'message': 'Only one candidate can be specified in the # current model', 'status': 'INVALID_ARGUMENT'}}] "remote::gemini",
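
For context, the `n` parameter exercised by `skip_if_doesnt_support_n` asks an OpenAI-compatible server for multiple candidate completions in a single request. A minimal sketch with openai-python; the base URL, API key, and model name are placeholders, not values from this patch:

```
from openai import OpenAI

# Placeholder endpoint, key, and model: adjust for your own stack.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    messages=[{"role": "user", "content": "Say hello."}],
    n=2,  # groq and gemini's OpenAI-compat endpoints reject n > 1, hence the skips
)
print(len(response.choices))  # 2 when the provider supports n
```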