chore: Rename RagTool to FileSearchTool

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo 2025-10-24 16:52:58 -04:00
parent 4566eebe05
commit 2d9163529a
288 changed files with 16985 additions and 2071 deletions
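
For downstream code written against the old names, the rename is mechanical. A minimal, hypothetical migration shim (the `FileSearch*` symbols are what this commit introduces; the aliases below are illustrative only — the commit ships no compatibility layer):

```python
# Hypothetical shim for callers still using the pre-rename names.
from llama_stack.apis.tools import (
    FileSearchConfig,
    FileSearchResult,
    FileSearchToolRuntime,
)

RAGQueryConfig = FileSearchConfig        # old -> new
RAGQueryResult = FileSearchResult        # old -> new
RAGToolRuntime = FileSearchToolRuntime   # old -> new
```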


@ -4,5 +4,5 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .rag_tool import *
from .file_search_tool import *
from .tools import *


@ -76,7 +76,7 @@ class RAGDocument(BaseModel):
@json_schema_type
class RAGQueryResult(BaseModel):
class FileSearchResult(BaseModel):
"""Result of a RAG query containing retrieved content and metadata.
:param content: (Optional) The retrieved content from the query
@ -88,7 +88,7 @@ class RAGQueryResult(BaseModel):
@json_schema_type
class RAGQueryGenerator(Enum):
class FileSearchGenerator(Enum):
"""Types of query generators for RAG systems.
:cvar default: Default query generator using simple text processing
@ -102,7 +102,7 @@ class RAGQueryGenerator(Enum):
@json_schema_type
class RAGSearchMode(StrEnum):
class FileSearchMode(StrEnum):
"""
Search modes for RAG query retrieval:
- VECTOR: Uses vector similarity search for semantic matching
@ -116,7 +116,7 @@ class RAGSearchMode(StrEnum):
@json_schema_type
class DefaultRAGQueryGeneratorConfig(BaseModel):
class DefaultFileSearchGeneratorConfig(BaseModel):
"""Configuration for the default RAG query generator.
:param type: Type of query generator, always 'default'
@ -128,8 +128,8 @@ class DefaultRAGQueryGeneratorConfig(BaseModel):
@json_schema_type
class LLMRAGQueryGeneratorConfig(BaseModel):
"""Configuration for the LLM-based RAG query generator.
class LLMFileSearchGeneratorConfig(BaseModel):
"""Configuration for the LLM-based File Search generator.
:param type: Type of query generator, always 'llm'
:param model: Name of the language model to use for query generation
@ -141,15 +141,15 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
template: str
RAGQueryGeneratorConfig = Annotated[
DefaultRAGQueryGeneratorConfig | LLMRAGQueryGeneratorConfig,
FileSearchGeneratorConfig = Annotated[
DefaultFileSearchGeneratorConfig | LLMFileSearchGeneratorConfig,
Field(discriminator="type"),
]
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
register_schema(FileSearchGeneratorConfig, name="FileSearchGeneratorConfig")
@json_schema_type
class RAGQueryConfig(BaseModel):
class FileSearchConfig(BaseModel):
"""
Configuration for the RAG query generation.
@ -165,11 +165,11 @@ class RAGQueryConfig(BaseModel):
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
query_generator_config: FileSearchGeneratorConfig = Field(default=DefaultFileSearchGeneratorConfig())
max_tokens_in_context: int = 4096
max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: RAGSearchMode | None = RAGSearchMode.VECTOR
mode: FileSearchMode | None = FileSearchMode.VECTOR
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
@field_validator("chunk_template")
@ -185,8 +185,8 @@ class RAGQueryConfig(BaseModel):
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
class FileSearchToolRuntime(Protocol):
@webmethod(route="/tool-runtime/file_search-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert(
self,
documents: list[RAGDocument],
@ -201,18 +201,18 @@ class RAGToolRuntime(Protocol):
"""
...
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
@webmethod(route="/tool-runtime/file_search-tool/query", method="POST", level=LLAMA_STACK_API_V1)
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
query_config: FileSearchConfig | None = None,
) -> FileSearchResult:
"""Query the RAG system for context; typically invoked by the agent.
:param content: The query content to search for in the indexed documents
:param vector_db_ids: List of vector database IDs to search within
:param query_config: (Optional) Configuration parameters for the query operation
:returns: RAGQueryResult containing the retrieved content and metadata
:returns: FileSearchResult containing the retrieved content and metadata
"""
...
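
For orientation, a minimal sketch of calling the renamed protocol surface; `runtime` is a stand-in for any concrete `FileSearchToolRuntime`, and the query text and vector DB ID are illustrative:

```python
from llama_stack.apis.tools import FileSearchConfig, FileSearchToolRuntime


async def run_query(runtime: FileSearchToolRuntime) -> None:
    # How a concrete runtime is obtained is deployment-specific
    # and not part of this diff.
    result = await runtime.query(
        content="What is the refund policy?",
        vector_db_ids=["docs-db"],  # illustrative ID
        query_config=FileSearchConfig(mode="vector", max_chunks=5),
    )
    print(result.content, result.metadata)
```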


@ -16,7 +16,7 @@ from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime
from .file_search_tool import FileSearchToolRuntime
@json_schema_type
@ -195,7 +195,7 @@ class SpecialToolGroup(Enum):
class ToolRuntime(Protocol):
tool_store: ToolStore | None = None
rag_tool: RAGToolRuntime | None = None
rag_tool: FileSearchToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)


@ -28,7 +28,7 @@ class ChunkMetadata(BaseModel):
"""
`ChunkMetadata` is backend metadata for a `Chunk` that is used to store additional information about the chunk that
will not be used in the context during inference, but is required for backend functionality. The `ChunkMetadata`
is set during chunk creation in `MemoryToolRuntimeImpl().insert()`and is not expected to change after.
is set during chunk creation in `FileSearchToolRuntimeImpl().insert()`and is not expected to change after.
Use `Chunk.metadata` for metadata that will be used in the context during inference.
:param chunk_id: The ID of the chunk. If not set, it will be generated based on the document ID and content.
:param document_id: The ID of the document this chunk belongs to.


@ -11,11 +11,11 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.tools import (
FileSearchConfig,
FileSearchResult,
FileSearchToolRuntime,
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
)
from llama_stack.log import get_logger
@ -26,21 +26,21 @@ logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
class FileSearchToolImpl(FileSearchToolRuntime):
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
logger.debug("Initializing ToolRuntimeRouter.FileSearchToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_store_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_store_ids}")
query_config: FileSearchConfig | None = None,
) -> FileSearchResult:
logger.debug(f"ToolRuntimeRouter.FileSearchToolImpl.query: {vector_store_ids}")
provider = await self.routing_table.get_provider_impl("knowledge_search")
return await provider.query(content, vector_store_ids, query_config)
@ -51,7 +51,7 @@ class ToolRuntimeRouter(ToolRuntime):
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
f"ToolRuntimeRouter.FileSearchToolImpl.insert: {vector_store_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
provider = await self.routing_table.get_provider_impl("insert_into_memory")
return await provider.insert(documents, vector_store_id, chunk_size_in_tokens)
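
And the matching ingestion call on the renamed router impl; note `RAGDocument` keeps its old name in this commit. A hypothetical sketch, with all field values illustrative:

```python
from llama_stack.apis.tools import RAGDocument


async def ingest(router) -> None:
    # `router` stands in for a ToolRuntimeRouter; keyword names follow
    # the insert signature shown in the hunk above.
    docs = [RAGDocument(document_id="doc-1", content="Refunds take 14 days.")]
    await router.rag_tool.insert(
        documents=docs,
        vector_store_id="docs-db",
        chunk_size_in_tokens=512,
    )
```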
@ -64,7 +64,7 @@ class ToolRuntimeRouter(ToolRuntime):
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
self.rag_tool = self.FileSearchToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
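
The loop above stores attributes whose names contain a literal dot, so they are reachable only through `getattr` with the exact string, not via attribute chaining. A standalone illustration (hypothetical snippet, not part of the diff):

```python
class Router:
    pass


r = Router()
setattr(r, "rag_tool.query", lambda: "routed")

# "rag_tool.query" is a single attribute key, not nested access:
print(getattr(r, "rag_tool.query")())  # -> routed
```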


@ -18,7 +18,7 @@ logger = get_logger(name=__name__, category="core::routing_tables")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
# handle the funny case like "builtin::rag/knowledge_search"
# handle the funny case like "builtin::file_search/knowledge_search"
parts = toolgroup_name_with_maybe_tool_name.split("/")
if len(parts) == 2:
return parts[0]
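
A standalone sketch mirroring the parsing behavior above; the `None` fallthrough for bare names is inferred from the `str | None` annotation, since that branch sits outside this hunk:

```python
def parse_toolgroup(name: str) -> str | None:
    # Mirrors the two-part branch shown above; bare names yield None.
    parts = name.split("/")
    if len(parts) == 2:
        return parts[0]
    return None


assert parse_toolgroup("builtin::file_search/knowledge_search") == "builtin::file_search"
assert parse_toolgroup("builtin::websearch") is None
```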


@ -13,7 +13,7 @@ from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.tools import FileSearchToolRuntime, SpecialToolGroup
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod
@ -27,7 +27,7 @@ RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map():
return {
SpecialToolGroup.rag_tool: RAGToolRuntime,
SpecialToolGroup.rag_tool: FileSearchToolRuntime,
}


@ -32,7 +32,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.tools import FileSearchToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
@ -80,7 +80,7 @@ class LlamaStack(
Inspect,
ToolGroups,
ToolRuntime,
RAGToolRuntime,
FileSearchToolRuntime,
Files,
Prompts,
Conversations,


@ -23,7 +23,7 @@ def main():
# Playground pages
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
rag_page = st.Page("page/playground/file_search.py", title="RAG", icon="💬", default=False)
tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)
# Distribution pages


@ -54,7 +54,7 @@ def tool_chat_page():
help="List of built-in tools from your llama stack server.",
)
if "builtin::rag" in toolgroup_selection:
if "builtin::file_search" in toolgroup_selection:
vector_stores = llama_stack_api.client.vector_stores.list() or []
if not vector_stores:
st.info("No vector databases available for selection.")
@ -115,9 +115,9 @@ def tool_chat_page():
)
for i, tool_name in enumerate(toolgroup_selection):
if tool_name == "builtin::rag":
if tool_name == "builtin::file_search":
tool_dict = dict(
name="builtin::rag",
name="builtin::file_search",
args={
"vector_store_ids": list(selected_vector_stores),
},


@ -48,7 +48,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: inline::file_search-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference


@ -216,8 +216,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
@ -263,8 +263,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -26,7 +26,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: inline::file_search-runtime
image_type: venv
additional_pip_packages:
- aiosqlite


@ -45,7 +45,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="inline::file_search-runtime"),
],
}
name = "dell"
@ -99,8 +99,8 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="brave-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
toolgroup_id="builtin::file_search",
provider_id="file_search-runtime",
),
]


@ -87,8 +87,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
storage:
backends:
kv_default:
@ -133,8 +133,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -83,8 +83,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
storage:
backends:
kv_default:
@ -124,8 +124,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: brave-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -24,7 +24,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: inline::file_search-runtime
- provider_type: remote::model-context-protocol
image_type: venv
additional_pip_packages:


@ -47,7 +47,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="inline::file_search-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
}
@ -93,8 +93,8 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
toolgroup_id="builtin::file_search",
provider_id="file_search-runtime",
),
]


@ -98,8 +98,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -146,8 +146,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -88,8 +88,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -131,8 +131,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -20,7 +20,7 @@ distribution_spec:
scoring:
- provider_type: inline::basic
tool_runtime:
- provider_type: inline::rag-runtime
- provider_type: inline::file_search-runtime
files:
- provider_type: inline::localfs
image_type: venv


@ -28,7 +28,7 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
BuildProvider(provider_type="remote::nvidia"),
],
"scoring": [BuildProvider(provider_type="inline::basic")],
"tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")],
"tool_runtime": [BuildProvider(provider_type="inline::file_search-runtime")],
"files": [BuildProvider(provider_type="inline::localfs")],
}
@ -68,8 +68,8 @@ def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
toolgroup_id="builtin::file_search",
provider_id="file_search-runtime",
),
]


@ -81,8 +81,8 @@ providers:
- provider_id: basic
provider_type: inline::basic
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@ -129,8 +129,8 @@ registered_resources:
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -70,8 +70,8 @@ providers:
- provider_id: basic
provider_type: inline::basic
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@ -108,8 +108,8 @@ registered_resources:
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -28,7 +28,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: inline::file_search-runtime
- provider_type: remote::model-context-protocol
image_type: venv
additional_pip_packages:


@ -118,7 +118,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="inline::file_search-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
}
@ -155,8 +155,8 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
toolgroup_id="builtin::file_search",
provider_id="file_search-runtime",
),
]


@ -118,8 +118,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -244,8 +244,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -14,7 +14,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: inline::file_search-runtime
- provider_type: remote::model-context-protocol
image_type: venv
additional_pip_packages:


@ -45,7 +45,7 @@ def get_distribution_template() -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="inline::file_search-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
}
@ -67,8 +67,8 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
toolgroup_id="builtin::file_search",
provider_id="file_search-runtime",
),
]


@ -54,8 +54,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
storage:
@ -107,8 +107,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -49,7 +49,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: inline::file_search-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference


@ -219,8 +219,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
@ -266,8 +266,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -49,7 +49,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: inline::file_search-runtime
- provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference


@ -216,8 +216,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
batches:
@ -263,8 +263,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -141,7 +141,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="inline::file_search-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"batches": [
@ -164,8 +164,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
toolgroup_id="builtin::file_search",
provider_id="file_search-runtime",
),
]
default_shields = [


@ -23,7 +23,7 @@ distribution_spec:
tool_runtime:
- provider_type: remote::brave-search
- provider_type: remote::tavily-search
- provider_type: inline::rag-runtime
- provider_type: inline::file_search-runtime
- provider_type: remote::model-context-protocol
files:
- provider_type: inline::localfs


@ -83,8 +83,8 @@ providers:
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: file_search-runtime
provider_type: inline::file_search-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
files:
@ -125,8 +125,8 @@ registered_resources:
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::file_search
provider_id: file_search-runtime
server:
port: 8321
telemetry:


@ -33,7 +33,7 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
"tool_runtime": [
BuildProvider(provider_type="remote::brave-search"),
BuildProvider(provider_type="remote::tavily-search"),
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="inline::file_search-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
@ -51,8 +51,8 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
toolgroup_id="builtin::file_search",
provider_id="file_search-runtime",
),
]


@ -86,7 +86,7 @@ from .safety import SafetyException, ShieldRunnerMixin
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
RAG_TOOL_GROUP = "builtin::file_search"
logger = get_logger(name=__name__, category="agents::meta_reference")
@ -927,14 +927,14 @@ class ChatAgent(ShieldRunnerMixin):
"""Parse a toolgroup name into its components.
Args:
toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
toolgroup_name: The toolgroup name to parse (e.g. "builtin::file_search/knowledge_search")
Returns:
A tuple of (tool_type, tool_group, tool_name)
"""
split_names = toolgroup_name_with_maybe_tool_name.split("/")
if len(split_names) == 2:
# e.g. "builtin::rag"
# e.g. "builtin::file_search"
tool_group, tool_name = split_names
else:
tool_group, tool_name = split_names[0], None


@ -12,8 +12,8 @@ from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl
from .file_search import FileSearchToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
impl = FileSearchToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
await impl.initialize()
return impl


@ -9,19 +9,19 @@ from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
RAGQueryGenerator,
RAGQueryGeneratorConfig,
from llama_stack.apis.tools.file_search_tool import (
DefaultFileSearchGeneratorConfig,
FileSearchGenerator,
FileSearchGeneratorConfig,
LLMFileSearchGeneratorConfig,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
async def generate_rag_query(
config: RAGQueryGeneratorConfig,
async def generate_file_search_query(
config: FileSearchGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
@ -29,25 +29,25 @@ async def generate_rag_query(
Generates a query that will be used for
retrieving relevant information from the memory bank.
"""
if config.type == RAGQueryGenerator.default.value:
query = await default_rag_query_generator(config, content, **kwargs)
elif config.type == RAGQueryGenerator.llm.value:
query = await llm_rag_query_generator(config, content, **kwargs)
if config.type == FileSearchGenerator.default.value:
query = await default_file_search_query_generator(config, content, **kwargs)
elif config.type == FileSearchGenerator.llm.value:
query = await llm_file_search_query_generator(config, content, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query
async def default_rag_query_generator(
config: DefaultRAGQueryGeneratorConfig,
async def default_file_search_query_generator(
config: DefaultFileSearchGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
return interleaved_content_as_str(content, sep=config.separator)
async def llm_rag_query_generator(
config: LLMRAGQueryGeneratorConfig,
async def llm_file_search_query_generator(
config: LLMFileSearchGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
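
A minimal sketch of the renamed dispatcher's default path; the provider module path is taken from the rename elsewhere in this commit, the config's field defaults are assumed, and no inference API is needed on this branch:

```python
import asyncio

from llama_stack.apis.tools.file_search_tool import DefaultFileSearchGeneratorConfig
from llama_stack.providers.inline.tool_runtime.file_search.context_retriever import (
    generate_file_search_query,
)


async def main() -> None:
    # The default generator just flattens interleaved content to a string.
    query = await generate_file_search_query(
        DefaultFileSearchGeneratorConfig(),
        "refund policy for enterprise plans",
    )
    print(query)


asyncio.run(main())
```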


@ -23,11 +23,11 @@ from llama_stack.apis.common.content_types import (
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
FileSearchConfig,
FileSearchResult,
FileSearchToolRuntime,
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolGroup,
ToolInvocationResult,
@ -45,7 +45,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
from .context_retriever import generate_file_search_query
log = get_logger(name=__name__, category="tool_runtime")
@ -91,7 +91,7 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
return content_str.encode("utf-8"), "text/plain"
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
class FileSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, FileSearchToolRuntime):
def __init__(
self,
config: RagToolRuntimeConfig,
@ -177,15 +177,15 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
query_config: FileSearchConfig | None = None,
) -> FileSearchResult:
if not vector_db_ids:
raise ValueError(
"No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
)
query_config = query_config or RAGQueryConfig()
query = await generate_rag_query(
query_config = query_config or FileSearchConfig()
query = await generate_file_search_query(
query_config.query_generator_config,
content,
inference_api=self.inference_api,
@ -218,7 +218,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
scores.append(score)
if not chunks:
return RAGQueryResult(content=None)
return FileSearchResult(content=None)
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) # type: ignore
@ -269,7 +269,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
)
)
return RAGQueryResult(
return FileSearchResult(
content=picked,
metadata={
"document_ids": [c.document_id for c in chunks[: len(picked)]],
@ -312,9 +312,9 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
vector_db_ids = kwargs.get("vector_db_ids", [])
query_config = kwargs.get("query_config")
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
query_config = TypeAdapter(FileSearchConfig).validate_python(query_config)
else:
query_config = RAGQueryConfig()
query_config = FileSearchConfig()
query = kwargs["query"]
result = await self.query(
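
Since tool invocation delivers `query_config` as a plain dict, `TypeAdapter` re-validates it into the renamed model. A small sketch of that step (assuming pydantic v2, which `TypeAdapter` implies, with illustrative field values):

```python
from pydantic import TypeAdapter

from llama_stack.apis.tools import FileSearchConfig

raw = {"mode": "keyword", "max_chunks": 3}
config = TypeAdapter(FileSearchConfig).validate_python(raw)
print(config.mode, config.max_chunks)  # -> keyword 3
```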


@ -18,7 +18,7 @@ def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
provider_type="inline::file_search-runtime",
pip_packages=DEFAULT_VECTOR_IO_DEPS
+ [
"tqdm",
@ -29,8 +29,8 @@ def available_providers() -> list[ProviderSpec]:
"sentencepiece",
"transformers",
],
module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
module="llama_stack.providers.inline.tool_runtime.file_search",
config_class="llama_stack.providers.inline.tool_runtime.file_search.config.RagToolRuntimeConfig",
api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
),


@ -241,33 +241,33 @@ Two ranker types are supported:
- alpha=1: Only use vector scores
- alpha=0.5: Equal weight to both (default)
Example using RAGQueryConfig with different search modes:
Example using FileSearchConfig with different search modes:
```python
from llama_stack.apis.tools import RAGQueryConfig, RRFRanker, WeightedRanker
from llama_stack.apis.tools import FileSearchConfig, RRFRanker, WeightedRanker
# Vector search
config = RAGQueryConfig(mode="vector", max_chunks=5)
config = FileSearchConfig(mode="vector", max_chunks=5)
# Keyword search
config = RAGQueryConfig(mode="keyword", max_chunks=5)
config = FileSearchConfig(mode="keyword", max_chunks=5)
# Hybrid search with custom RRF ranker
config = RAGQueryConfig(
config = FileSearchConfig(
mode="hybrid",
max_chunks=5,
ranker=RRFRanker(impact_factor=50.0), # Custom impact factor
)
# Hybrid search with weighted ranker
config = RAGQueryConfig(
config = FileSearchConfig(
mode="hybrid",
max_chunks=5,
ranker=WeightedRanker(alpha=0.7), # 70% vector, 30% keyword
)
# Hybrid search with default RRF ranker
config = RAGQueryConfig(
config = FileSearchConfig(
mode="hybrid", max_chunks=5
) # Will use RRF with impact_factor=60.0
```


@ -144,7 +144,7 @@ const mockModels = [
const mockToolgroups = [
{
identifier: "builtin::rag",
identifier: "builtin::file_search",
provider_id: "test-provider",
type: "tool_group",
provider_resource_id: "test-resource",
@ -171,7 +171,7 @@ describe("ChatPlaygroundPage", () => {
mockClient.agents.retrieve.mockResolvedValue({
agent_id: "test-agent",
agent_config: {
toolgroups: ["builtin::rag"],
toolgroups: ["builtin::file_search"],
instructions: "Test instructions",
model: "test-model",
},
@ -629,7 +629,7 @@ describe("ChatPlaygroundPage", () => {
agent_config: {
toolgroups: [
{
name: "builtin::rag/knowledge_search",
name: "builtin::file_search/knowledge_search",
args: { vector_db_ids: ["test-vector-db"] },
},
],
@ -664,7 +664,7 @@ describe("ChatPlaygroundPage", () => {
agent_config: {
toolgroups: [
{
name: "builtin::rag/knowledge_search",
name: "builtin::file_search/knowledge_search",
args: { vector_db_ids: ["test-vector-db"] },
},
],


@ -433,7 +433,7 @@ export default function ChatPlaygroundPage() {
) => {
try {
const processedToolgroups = toolgroups.map(toolgroup => {
if (toolgroup === "builtin::rag" && vectorDBs.length > 0) {
if (toolgroup === "builtin::file_search" && vectorDBs.length > 0) {
return {
name: "builtin::rag/knowledge_search",
args: {
@ -1167,7 +1167,7 @@ export default function ChatPlaygroundPage() {
// find RAG toolgroups that have vector_db_ids configured
const ragToolgroups = selectedAgentConfig.toolgroups.filter(toolgroup => {
if (typeof toolgroup === "object" && toolgroup.name?.includes("rag")) {
if (typeof toolgroup === "object" && toolgroup.name?.includes("file_search")) {
return toolgroup.args && "vector_db_ids" in toolgroup.args;
}
return false;
@ -1505,7 +1505,7 @@ export default function ChatPlaygroundPage() {
const toolArgs =
typeof toolgroup === "object" ? toolgroup.args : null;
const isRAGTool = toolName.includes("rag");
const isRAGTool = toolName.includes("file_search");
const displayName = isRAGTool ? "RAG Search" : toolName;
const displayIcon = isRAGTool
? "🔍"
@ -1761,7 +1761,7 @@ export default function ChatPlaygroundPage() {
</div>
{/* Vector DB Configuration for RAG */}
{selectedToolgroups.includes("builtin::rag") && (
{selectedToolgroups.includes("builtin::file_search") && (
<div>
<label className="text-sm font-medium block mb-2">
Vector Databases for RAG
@ -1825,7 +1825,7 @@ export default function ChatPlaygroundPage() {
)}
</div>
{selectedVectorDBs.length === 0 &&
selectedToolgroups.includes("builtin::rag") && (
selectedToolgroups.includes("builtin::file_search") && (
<p className="text-xs text-muted-foreground mt-1">
RAG tool selected but no vector databases chosen.
Create or select a vector database.